1919# cython: linetrace=True
2020# cython: freethreading_compatible = True
2121
22- """ This module implements several device creation helper functions:
22+ """
23+ This module implements several device creation helper functions:
2324
2425 - wrapper functions to create a SyclDevice from the standard SYCL
2526 device selector classes.
@@ -47,7 +48,7 @@ from ._backend cimport ( # noqa: E211
4748 _device_type,
4849)
4950
50- from contextvars import ContextVar
51+ import threading
5152
5253from ._sycl_device import SyclDeviceCreationError
5354from .enum_types import backend_type
@@ -287,7 +288,8 @@ cpdef int get_num_devices(
287288
288289
289290cpdef cpp_bool has_cpu_devices():
290- """ A helper function to check if there are any SYCL CPU devices available.
291+ """
292+ A helper function to check if there are any SYCL CPU devices available.
291293
292294 Returns:
293295 bool:
@@ -299,7 +301,8 @@ cpdef cpp_bool has_cpu_devices():
299301
300302
301303cpdef cpp_bool has_gpu_devices():
302- """ A helper function to check if there are any SYCL GPU devices available.
304+ """
305+ A helper function to check if there are any SYCL GPU devices available.
303306
304307 Returns:
305308 bool:
@@ -311,7 +314,8 @@ cpdef cpp_bool has_gpu_devices():
311314
312315
313316cpdef cpp_bool has_accelerator_devices():
314- """ A helper function to check if there are any SYCL Accelerator devices
317+ """
318+ A helper function to check if there are any SYCL Accelerator devices
315319 available.
316320
317321 Returns:
@@ -326,7 +330,8 @@ cpdef cpp_bool has_accelerator_devices():
326330
327331
328332cpdef SyclDevice select_accelerator_device():
329- """ A wrapper for ``sycl::device{sycl::accelerator_selector_v}`` constructor.
333+ """
334+ A wrapper for ``sycl::device{sycl::accelerator_selector_v}`` constructor.
330335
331336 Returns:
332337 dpctl.SyclDevice:
@@ -348,7 +353,8 @@ cpdef SyclDevice select_accelerator_device():
348353
349354
350355cpdef SyclDevice select_cpu_device():
351- """ A wrapper for ``sycl::device{sycl::cpu_selector_v}`` constructor.
356+ """
357+ A wrapper for ``sycl::device{sycl::cpu_selector_v}`` constructor.
352358
353359 Returns:
354360 dpctl.SyclDevice:
@@ -370,7 +376,8 @@ cpdef SyclDevice select_cpu_device():
370376
371377
372378cpdef SyclDevice select_default_device():
373- """ A wrapper for ``sycl::device{sycl::default_selector_v}`` constructor.
379+ """
380+ A wrapper for ``sycl::device{sycl::default_selector_v}`` constructor.
374381
375382 Returns:
376383 dpctl.SyclDevice:
@@ -392,7 +399,8 @@ cpdef SyclDevice select_default_device():
392399
393400
394401cpdef SyclDevice select_gpu_device():
395- """ A wrapper for ``sycl::device{sycl::gpu_selector_v}`` constructor.
402+ """
403+ A wrapper for ``sycl::device{sycl::gpu_selector_v}`` constructor.
396404
397405 Returns:
398406 dpctl.SyclDevice:
@@ -415,21 +423,23 @@ cpdef SyclDevice select_gpu_device():
415423
416424cdef class _DefaultDeviceCache:
417425 cdef dict __device_map__
426+ cdef object _cache_lock
418427
419428 def __cinit__ (self ):
420429 self .__device_map__ = dict ()
421-
422- cdef get_or_create(self ):
423- """ Return instance of SyclDevice and indicator if cache
424- has been modified"""
425- key = 0
426- if key in self .__device_map__:
427- return self .__device_map__[key], False
428- dev = select_default_device()
429- self .__device_map__[key] = dev
430- return dev, True
431-
432- cdef _update_map(self , dev_map):
430+ self ._cache_lock = threading.Lock()
431+
432+ def get_or_create (self ):
433+ """ Return cached default SyclDevice, creating it if needed."""
434+ with self ._cache_lock:
435+ key = 0
436+ if key in self .__device_map__:
437+ return self .__device_map__[key]
438+ dev = select_default_device()
439+ self .__device_map__[key] = dev
440+ return dev
441+
442+ def _update_map (self , dev_map ):
433443 self .__device_map__.update(dev_map)
434444
435445 def __copy__ (self ):
@@ -439,37 +449,17 @@ cdef class _DefaultDeviceCache:
439449 return _copy
440450
441451
442- # no default, as would share a single mutable instance across threads and
443- # concurrent access to the cache would not be thread-safe. Using ContextVar
444- # without a default ensures each context gets its own instance.
445- _global_default_device_cache = ContextVar(
446- " global_default_device_cache" ,
447- )
448-
449-
450- cdef _DefaultDeviceCache _get_default_device_cache():
451- """
452- Factory function to get or create a default device cache for the current
453- context
454- """
455- try :
456- return _global_default_device_cache.get()
457- except LookupError :
458- cache = _DefaultDeviceCache()
459- _global_default_device_cache.set(cache)
460- return cache
452+ # All threads share this single cache instance; get_or_create() takes the cache's lock.
453+ _global_default_device_cache = _DefaultDeviceCache()
461454
462455
463456cpdef SyclDevice _cached_default_device():
464- """ Returns a cached device selected by default selector.
457+ """
458+ Returns a cached device selected by default selector.
465459
466460 Returns:
467461 dpctl.SyclDevice:
468462 A cached default-selected SYCL device.
469463
470464 """
471- cdef _DefaultDeviceCache _cache = _get_default_device_cache()
472- d_, changed_ = _cache.get_or_create()
473- if changed_:
474- _global_default_device_cache.set(_cache)
475- return d_
465+ return _global_default_device_cache.get_or_create()
0 commit comments