Skip to content

Commit ebf970f

Browse files
committed
make SequentialOrderManager thread-local and cached queues, devices global
1 parent c3b3fb0 commit ebf970f

File tree

4 files changed

+75
-100
lines changed

4 files changed

+75
-100
lines changed

dpctl/_sycl_device_factory.pyx

Lines changed: 36 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
# cython: linetrace=True
2020
# cython: freethreading_compatible = True
2121

22-
""" This module implements several device creation helper functions:
22+
"""
23+
This module implements several device creation helper functions:
2324
2425
- wrapper functions to create a SyclDevice from the standard SYCL
2526
device selector classes.
@@ -47,7 +48,7 @@ from ._backend cimport ( # noqa: E211
4748
_device_type,
4849
)
4950

50-
from contextvars import ContextVar
51+
import threading
5152

5253
from ._sycl_device import SyclDeviceCreationError
5354
from .enum_types import backend_type
@@ -287,7 +288,8 @@ cpdef int get_num_devices(
287288

288289

289290
cpdef cpp_bool has_cpu_devices():
290-
""" A helper function to check if there are any SYCL CPU devices available.
291+
"""
292+
A helper function to check if there are any SYCL CPU devices available.
291293
292294
Returns:
293295
bool:
@@ -299,7 +301,8 @@ cpdef cpp_bool has_cpu_devices():
299301

300302

301303
cpdef cpp_bool has_gpu_devices():
302-
""" A helper function to check if there are any SYCL GPU devices available.
304+
"""
305+
A helper function to check if there are any SYCL GPU devices available.
303306
304307
Returns:
305308
bool:
@@ -311,7 +314,8 @@ cpdef cpp_bool has_gpu_devices():
311314

312315

313316
cpdef cpp_bool has_accelerator_devices():
314-
""" A helper function to check if there are any SYCL Accelerator devices
317+
"""
318+
A helper function to check if there are any SYCL Accelerator devices
315319
available.
316320
317321
Returns:
@@ -326,7 +330,8 @@ cpdef cpp_bool has_accelerator_devices():
326330

327331

328332
cpdef SyclDevice select_accelerator_device():
329-
"""A wrapper for ``sycl::device{sycl::accelerator_selector_v}`` constructor.
333+
"""
334+
A wrapper for ``sycl::device{sycl::accelerator_selector_v}`` constructor.
330335
331336
Returns:
332337
dpctl.SyclDevice:
@@ -348,7 +353,8 @@ cpdef SyclDevice select_accelerator_device():
348353

349354

350355
cpdef SyclDevice select_cpu_device():
351-
"""A wrapper for ``sycl::device{sycl::cpu_selector_v}`` constructor.
356+
"""
357+
A wrapper for ``sycl::device{sycl::cpu_selector_v}`` constructor.
352358
353359
Returns:
354360
dpctl.SyclDevice:
@@ -370,7 +376,8 @@ cpdef SyclDevice select_cpu_device():
370376

371377

372378
cpdef SyclDevice select_default_device():
373-
"""A wrapper for ``sycl::device{sycl::default_selector_v}`` constructor.
379+
"""
380+
A wrapper for ``sycl::device{sycl::default_selector_v}`` constructor.
374381
375382
Returns:
376383
dpctl.SyclDevice:
@@ -392,7 +399,8 @@ cpdef SyclDevice select_default_device():
392399

393400

394401
cpdef SyclDevice select_gpu_device():
395-
"""A wrapper for ``sycl::device{sycl::gpu_selector_v}`` constructor.
402+
"""
403+
A wrapper for ``sycl::device{sycl::gpu_selector_v}`` constructor.
396404
397405
Returns:
398406
dpctl.SyclDevice:
@@ -415,21 +423,23 @@ cpdef SyclDevice select_gpu_device():
415423

416424
cdef class _DefaultDeviceCache:
417425
cdef dict __device_map__
426+
cdef object _cache_lock
418427

419428
def __cinit__(self):
420429
self.__device_map__ = dict()
421-
422-
cdef get_or_create(self):
423-
"""Return instance of SyclDevice and indicator if cache
424-
has been modified"""
425-
key = 0
426-
if key in self.__device_map__:
427-
return self.__device_map__[key], False
428-
dev = select_default_device()
429-
self.__device_map__[key] = dev
430-
return dev, True
431-
432-
cdef _update_map(self, dev_map):
430+
self._cache_lock = threading.Lock()
431+
432+
def get_or_create(self):
433+
"""Return cached default SyclDevice, creating it if needed."""
434+
with self._cache_lock:
435+
key = 0
436+
if key in self.__device_map__:
437+
return self.__device_map__[key]
438+
dev = select_default_device()
439+
self.__device_map__[key] = dev
440+
return dev
441+
442+
def _update_map(self, dev_map):
433443
self.__device_map__.update(dev_map)
434444

435445
def __copy__(self):
@@ -439,37 +449,17 @@ cdef class _DefaultDeviceCache:
439449
return _copy
440450

441451

442-
# no default, as would share a single mutable instance across threads and
443-
# concurrent access to the cache would not be thread-safe. Using ContextVar
444-
# without a default ensures each context gets its own instance.
445-
_global_default_device_cache = ContextVar(
446-
"global_default_device_cache",
447-
)
448-
449-
450-
cdef _DefaultDeviceCache _get_default_device_cache():
451-
"""
452-
Factory function to get or create a default device cache for the current
453-
context
454-
"""
455-
try:
456-
return _global_default_device_cache.get()
457-
except LookupError:
458-
cache = _DefaultDeviceCache()
459-
_global_default_device_cache.set(cache)
460-
return cache
452+
# all threads share the same cached default
453+
_global_default_device_cache = _DefaultDeviceCache()
461454

462455

463456
cpdef SyclDevice _cached_default_device():
464-
"""Returns a cached device selected by default selector.
457+
"""
458+
Returns a cached device selected by default selector.
465459
466460
Returns:
467461
dpctl.SyclDevice:
468462
A cached default-selected SYCL device.
469463
470464
"""
471-
cdef _DefaultDeviceCache _cache = _get_default_device_cache()
472-
d_, changed_ = _cache.get_or_create()
473-
if changed_:
474-
_global_default_device_cache.set(_cache)
475-
return d_
465+
return _global_default_device_cache.get_or_create()

dpctl/_sycl_queue_manager.pyx

Lines changed: 18 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
# cython: freethreading_compatible = True
2121

2222
import logging
23-
from contextvars import ContextVar
23+
import threading
2424
from ._sycl_context cimport SyclContext
2525
from ._sycl_device cimport SyclDevice
2626

@@ -34,13 +34,14 @@ _logger = logging.getLogger(__name__)
3434

3535
cdef class _DeviceDefaultQueueCache:
3636
cdef dict __device_queue_map__
37+
cdef object _cache_lock
3738

3839
def __cinit__(self):
3940
self.__device_queue_map__ = dict()
41+
self._cache_lock = threading.Lock()
4042

4143
def get_or_create(self, key):
42-
"""Return instance of SyclQueue and indicator if cache
43-
has been modified"""
44+
"""Return cached SyclQueue for given key, creating it if needed."""
4445
if (
4546
isinstance(key, tuple)
4647
and len(key) == 2
@@ -57,14 +58,15 @@ cdef class _DeviceDefaultQueueCache:
5758
ctx_dev = q.sycl_context, q.sycl_device
5859
else:
5960
raise TypeError
60-
if ctx_dev in self.__device_queue_map__:
61-
return self.__device_queue_map__[ctx_dev], False
62-
if q is None:
63-
q = SyclQueue(*ctx_dev)
64-
self.__device_queue_map__[ctx_dev] = q
65-
return q, True
66-
67-
cdef _update_map(self, dev_queue_map):
61+
with self._cache_lock:
62+
if ctx_dev in self.__device_queue_map__:
63+
return self.__device_queue_map__[ctx_dev]
64+
if q is None:
65+
q = SyclQueue(*ctx_dev)
66+
self.__device_queue_map__[ctx_dev] = q
67+
return q
68+
69+
def _update_map(self, dev_queue_map):
6870
self.__device_queue_map__.update(dev_queue_map)
6971

7072
def __copy__(self):
@@ -75,29 +77,13 @@ cdef class _DeviceDefaultQueueCache:
7577
return _copy
7678

7779

78-
# no default, as would share a single mutable instance across threads and
79-
# concurrent access to the cache would not be thread-safe. Using ContextVar
80-
# without a default ensures each context gets its own instance.
81-
_global_device_queue_cache = ContextVar(
82-
"global_device_queue_cache",
83-
)
84-
85-
86-
cdef _DeviceDefaultQueueCache _get_device_queue_cache():
87-
"""
88-
Factory function to get or create a default device queue cache for the
89-
current context
90-
"""
91-
try:
92-
return _global_device_queue_cache.get()
93-
except LookupError:
94-
cache = _DeviceDefaultQueueCache()
95-
_global_device_queue_cache.set(cache)
96-
return cache
80+
# all threads share the same cached default
81+
_global_device_queue_cache = _DeviceDefaultQueueCache()
9782

9883

9984
cpdef object get_device_cached_queue(object key):
100-
"""Returns a cached queue associated with given device.
85+
"""
86+
Returns a cached queue associated with given device.
10187
10288
Args:
10389
key : Either a 2-tuple consisting of a :class:`dpctl.SyclContext` and
@@ -112,8 +98,4 @@ cpdef object get_device_cached_queue(object key):
11298
TypeError: If the input key is not one of the accepted types.
11399
114100
"""
115-
_cache = _get_device_queue_cache()
116-
q_, changed_ = _cache.get_or_create(key)
117-
if changed_:
118-
_global_device_queue_cache.set(_cache)
119-
return q_
101+
return _global_device_queue_cache.get_or_create(key)

dpctl/tests/test_sycl_queue_manager.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,20 @@
2424
def test__DeviceDefaultQueueCache():
2525
import copy
2626

27-
from dpctl._sycl_queue_manager import _global_device_queue_cache as cache
28-
from dpctl._sycl_queue_manager import get_device_cached_queue
27+
from dpctl._sycl_queue_manager import (
28+
_global_device_queue_cache,
29+
get_device_cached_queue,
30+
)
2931

3032
try:
3133
d = dpctl.SyclDevice()
3234
except dpctl.SyclDeviceCreationError:
3335
pytest.skip("Could not create default device")
3436

3537
q1 = get_device_cached_queue(d)
36-
cache_copy = copy.copy(cache.get())
37-
q2, changed = cache_copy.get_or_create(d)
38+
cache_copy = copy.copy(_global_device_queue_cache)
39+
q2 = cache_copy.get_or_create(d)
3840

39-
assert not changed
4041
assert q1 == q2
4142
q3 = get_device_cached_queue(d.filter_string)
4243
assert q3 == q1

dpctl/utils/_order_manager.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1+
import threading
12
import weakref
23
from collections import defaultdict
3-
from contextvars import ContextVar
44

55
from .._sycl_event import SyclEvent
66
from .._sycl_queue import SyclQueue
@@ -65,26 +65,28 @@ def __copy__(self):
6565

6666

6767
class SyclQueueToOrderManagerMap:
68-
"""Utility class used to ensure sequential ordering of offloaded tasks
69-
when passed to order manager."""
68+
"""
69+
Utility class used to ensure sequential ordering of offloaded tasks
70+
when passed to order manager.
71+
72+
Maintains a thread-local dictionary mapping SyclQueue instances to
73+
_SequentialOrderManager instances.
74+
"""
7075

7176
def __init__(self):
72-
self._map = ContextVar(
73-
"global_order_manager_map",
74-
# no default to avoid sharing a single defaultdict
75-
# across threads
76-
)
77+
# each thread gets its own dictionary of order managers
78+
self._tls = threading.local()
7779

7880
def _get_map(self):
7981
"""
80-
Factory method to get or create a default device queue cache for the
81-
current context
82+
Factory method to get or create an order manager map for the
83+
current thread.
8284
"""
8385
try:
84-
return self._map.get()
85-
except LookupError:
86+
return self._tls.order_manager_map
87+
except AttributeError:
8688
m = defaultdict(_SequentialOrderManager)
87-
self._map.set(m)
89+
self._tls.order_manager_map = m
8890
return m
8991

9092
def __getitem__(self, q: SyclQueue) -> _SequentialOrderManager:

0 commit comments

Comments
 (0)