Skip to content

Commit 86470b6

Browse files
mridul-sahuOrbax Authors
authored and committed
Open-source async_utils_test.py, and respect the operation id if one is passed to create_all_async when updating the awaitable signals contract.
PiperOrigin-RevId: 842618298
1 parent 7da0ece commit 86470b6

File tree

3 files changed

+137
-8
lines changed

3 files changed

+137
-8
lines changed

checkpoint/orbax/checkpoint/_src/futures/future.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -102,27 +102,32 @@ def get_awaitable_signals_from_contract(
102102
def add_to_awaitable_signals_contract(
103103
cls,
104104
signals: Sequence[synchronization.HandlerAwaitableSignal],
105+
operation_id: str | None = None,
105106
):
106107
"""Adds awaitable signals to `AWAITABLE_SIGNALS_CONTRACT` for lower checkpointing layers to wait on.
107108
108-
These signals are added to the list of awaitable signals for the current
109-
operation id in `OperationIdGenerator`.
109+
These signals are added to the list of awaitable signals for the given
110+
operation id or the current operation id in `OperationIdGenerator`.
110111
111112
Args:
112113
signals: The signals to add to the list of awaitable signals.
114+
operation_id: The operation id to use for the barrier keys. If None, the
115+
current operation id is used.
113116
"""
114117
if not signals:
115118
return
116-
117-
current_signals = list(cls.get_awaitable_signals_from_contract())
119+
operation_id = (
120+
operation_id
121+
or synchronization.OperationIdGenerator.get_current_operation_id()
122+
)
123+
current_signals = list(
124+
cls.get_awaitable_signals_from_contract(operation_id=operation_id)
125+
)
118126
current_signals.extend(signals)
119127
keys = ','.join(
120128
[current_signal.value for current_signal in current_signals]
121129
)
122130
client = signaling_client.get_signaling_client()
123-
operation_id = (
124-
synchronization.OperationIdGenerator.get_current_operation_id()
125-
)
126131
barrier_key = cls.get_unique_awaitable_singal_key(
127132
synchronization.HandlerAwaitableSignal.AWAITABLE_SIGNALS_CONTRACT,
128133
operation_id,

checkpoint/orbax/checkpoint/_src/path/atomicity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -723,7 +723,7 @@ def create_all_async(
723723
operation_id=operation_id,
724724
)
725725
future.AwaitableSignalsContract.add_to_awaitable_signals_contract(
726-
completion_signals
726+
completion_signals, operation_id=operation_id
727727
)
728728

729729
# Sync to ensure that all hosts have the same awaitable signals contract.
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Copyright 2025 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import asyncio
16+
import time
17+
import unittest
18+
from unittest import mock
19+
from absl.testing import absltest
20+
from etils import epath
21+
from orbax.checkpoint._src.path import async_path
22+
from orbax.checkpoint._src.path import atomicity
23+
from orbax.checkpoint.experimental.v1._src.path import async_utils
24+
25+
26+
class AsyncUtilsTest(absltest.TestCase, unittest.IsolatedAsyncioTestCase):
27+
28+
def setUp(self):
29+
super().setUp()
30+
31+
orig_create_paths = atomicity._create_paths
32+
33+
async def _sleep_and_create_paths(
34+
*args,
35+
**kwargs,
36+
):
37+
await asyncio.sleep(1)
38+
return await orig_create_paths(*args, **kwargs)
39+
40+
self.enter_context(
41+
mock.patch.object(
42+
atomicity, '_create_paths', new=_sleep_and_create_paths
43+
)
44+
)
45+
self.directory = epath.Path(self.create_tempdir().full_path)
46+
47+
async def assertExists(self, path: epath.Path):
48+
self.assertTrue(await async_path.exists(path))
49+
50+
async def assertNotExists(self, path: epath.Path):
51+
self.assertFalse(await async_path.exists(path))
52+
53+
def assertBetween(self, a, b, c):
54+
self.assertGreater(b, a)
55+
self.assertGreater(c, b)
56+
57+
async def test_async_mkdir(self):
58+
tmpdir = atomicity.AtomicRenameTemporaryPath(
59+
self.directory / 'tmp', self.directory / 'final'
60+
)
61+
start = time.time()
62+
p = async_utils.start_async_mkdir(tmpdir, operation_id='op1')
63+
await p.await_creation()
64+
self.assertBetween(1, time.time() - start, 2)
65+
await self.assertExists(self.directory / 'tmp')
66+
67+
async def test_async_mkdir_with_subdirectories(self):
68+
tmpdir = atomicity.AtomicRenameTemporaryPath(
69+
self.directory / 'tmp', self.directory / 'final'
70+
)
71+
start = time.time()
72+
p = async_utils.start_async_mkdir(tmpdir, ['a', 'b'], operation_id='op1')
73+
await p.await_creation()
74+
self.assertBetween(1, time.time() - start, 2)
75+
await self.assertExists(self.directory / 'tmp')
76+
await self.assertExists(self.directory / 'tmp' / 'a')
77+
await self.assertExists(self.directory / 'tmp' / 'b')
78+
79+
async def test_async_mkdir_sequential(self):
80+
tmpdir1 = atomicity.AtomicRenameTemporaryPath(
81+
self.directory / 'tmp1', self.directory / 'final1'
82+
)
83+
tmpdir2 = atomicity.AtomicRenameTemporaryPath(
84+
self.directory / 'tmp2', self.directory / 'final2'
85+
)
86+
start = time.time()
87+
p1 = async_utils.start_async_mkdir(tmpdir1, operation_id='op1')
88+
p2 = async_utils.start_async_mkdir(tmpdir2, operation_id='op2')
89+
await p1.await_creation()
90+
await p2.await_creation()
91+
# Awaiting sequentially does not take any longer than awaiting in parallel,
92+
# because the operations are already in progress.
93+
self.assertBetween(1, time.time() - start, 2)
94+
await self.assertExists(self.directory / 'tmp1')
95+
await self.assertExists(self.directory / 'tmp2')
96+
97+
async def test_async_mkdir_parallel(self):
98+
tmpdir1 = atomicity.AtomicRenameTemporaryPath(
99+
self.directory / 'tmp1', self.directory / 'final1'
100+
)
101+
tmpdir2 = atomicity.AtomicRenameTemporaryPath(
102+
self.directory / 'tmp2', self.directory / 'final2'
103+
)
104+
start = time.time()
105+
p1 = async_utils.start_async_mkdir(tmpdir1, operation_id='op1')
106+
p2 = async_utils.start_async_mkdir(tmpdir2, operation_id='op2')
107+
await asyncio.gather(p1.await_creation(), p2.await_creation())
108+
self.assertBetween(1, time.time() - start, 2)
109+
await self.assertExists(self.directory / 'tmp1')
110+
await self.assertExists(self.directory / 'tmp2')
111+
112+
async def test_async_mkdir_with_delayed_wait(self):
113+
tmpdir = atomicity.AtomicRenameTemporaryPath(
114+
self.directory / 'tmp', self.directory / 'final'
115+
)
116+
p = async_utils.start_async_mkdir(tmpdir, operation_id='op1')
117+
await self.assertNotExists(self.directory / 'tmp')
118+
await asyncio.sleep(1)
119+
await p.await_creation()
120+
await self.assertExists(self.directory / 'tmp')
121+
122+
123+
if __name__ == '__main__':
124+
absltest.main()

0 commit comments

Comments
 (0)