Skip to content

Commit d6a6ebd

Browse files
mridul-sahuOrbax Authors
authored andcommitted
OSS async_utils_test.py and use operationId generator instead of passing operation id around.
PiperOrigin-RevId: 842618298
1 parent 7da0ece commit d6a6ebd

File tree

3 files changed

+132
-8
lines changed

3 files changed

+132
-8
lines changed

checkpoint/orbax/checkpoint/_src/path/atomicity.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -675,7 +675,6 @@ def create_all_async(
675675
*,
676676
multiprocessing_options: options_lib.MultiprocessingOptions | None = None,
677677
subdirectories: Sequence[str] | None = None,
678-
operation_id: str | None = None,
679678
) -> future.Future:
680679
"""Creates all temporary paths in parallel asynchronously.
681680
@@ -688,8 +687,6 @@ def create_all_async(
688687
subdirectories: Sequence of subdirectories to create under `paths`. If not
689688
provided, no subdirectories will be created. The same set of
690689
subdirectories will be created under each path in `paths`.
691-
operation_id: The operation id to use for the barrier keys. If None, the
692-
current operation id is used.
693690
694691
Returns:
695692
A future that which sends the completion signals when all paths are created.
@@ -720,7 +717,6 @@ def create_all_async(
720717
),
721718
send_signals=completion_signals,
722719
timeout_secs=multihost.coordination_timeout(),
723-
operation_id=operation_id,
724720
)
725721
future.AwaitableSignalsContract.add_to_awaitable_signals_contract(
726722
completion_signals

checkpoint/orbax/checkpoint/experimental/v1/_src/path/async_utils.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ async def await_creation(self) -> Path:
5656
def start_async_mkdir(
5757
path: atomicity_types.TemporaryPath,
5858
subdirectories: Iterable[str] = (),
59-
operation_id: str | None = None,
6059
) -> PathAwaitingCreation:
6160
"""Starts async directory creation on a TemporaryPath.
6261
@@ -76,8 +75,6 @@ def start_async_mkdir(
7675
Args:
7776
path: The path to create. May be an instance of `TemporaryPath`.
7877
subdirectories: A sequence of subdirectories to create under `path`.
79-
operation_id: The operation id to use for the barrier keys. If None, the
80-
current operation id is used.
8178
8279
Returns:
8380
A PathAwaitingCreation object.
@@ -97,6 +94,5 @@ def start_async_mkdir(
9794
completion_signals=completion_signals,
9895
multiprocessing_options=context.multiprocessing_options.v0(),
9996
subdirectories=[name for name in subdirectories],
100-
operation_id=operation_id,
10197
)
10298
return _PathAwaitingCreation(path.get(), f)
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# Copyright 2025 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import asyncio
16+
import time
17+
import unittest
18+
from unittest import mock
19+
from absl.testing import absltest
20+
from etils import epath
21+
from orbax.checkpoint._src.path import async_path
22+
from orbax.checkpoint._src.path import atomicity
23+
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
24+
from orbax.checkpoint.experimental.v1._src.path import async_utils
25+
26+
27+
class AsyncUtilsTest(absltest.TestCase, unittest.IsolatedAsyncioTestCase):
28+
29+
def setUp(self):
30+
super().setUp()
31+
32+
orig_create_paths = atomicity._create_paths
33+
34+
async def _sleep_and_create_paths(
35+
*args,
36+
**kwargs,
37+
):
38+
await asyncio.sleep(1)
39+
return await orig_create_paths(*args, **kwargs)
40+
41+
self.enter_context(
42+
mock.patch.object(
43+
atomicity, '_create_paths', new=_sleep_and_create_paths
44+
)
45+
)
46+
self.directory = epath.Path(self.create_tempdir().full_path)
47+
48+
async def assertExists(self, path: epath.Path):
49+
self.assertTrue(await async_path.exists(path))
50+
51+
async def assertNotExists(self, path: epath.Path):
52+
self.assertFalse(await async_path.exists(path))
53+
54+
def assertBetween(self, a, b, c):
55+
self.assertGreater(b, a)
56+
self.assertGreater(c, b)
57+
58+
async def test_async_mkdir(self):
59+
await context_lib.synchronize_next_operation_id()
60+
tmpdir = atomicity.AtomicRenameTemporaryPath(
61+
self.directory / 'tmp', self.directory / 'final'
62+
)
63+
start = time.time()
64+
p = async_utils.start_async_mkdir(tmpdir)
65+
await p.await_creation()
66+
self.assertBetween(1, time.time() - start, 2)
67+
await self.assertExists(self.directory / 'tmp')
68+
69+
async def test_async_mkdir_with_subdirectories(self):
70+
await context_lib.synchronize_next_operation_id()
71+
tmpdir = atomicity.AtomicRenameTemporaryPath(
72+
self.directory / 'tmp', self.directory / 'final'
73+
)
74+
start = time.time()
75+
p = async_utils.start_async_mkdir(tmpdir, ['a', 'b'])
76+
await p.await_creation()
77+
self.assertBetween(1, time.time() - start, 2)
78+
await self.assertExists(self.directory / 'tmp')
79+
await self.assertExists(self.directory / 'tmp' / 'a')
80+
await self.assertExists(self.directory / 'tmp' / 'b')
81+
82+
async def test_async_mkdir_sequential(self):
83+
tmpdir1 = atomicity.AtomicRenameTemporaryPath(
84+
self.directory / 'tmp1', self.directory / 'final1'
85+
)
86+
tmpdir2 = atomicity.AtomicRenameTemporaryPath(
87+
self.directory / 'tmp2', self.directory / 'final2'
88+
)
89+
start = time.time()
90+
await context_lib.synchronize_next_operation_id()
91+
p1 = async_utils.start_async_mkdir(tmpdir1)
92+
await context_lib.synchronize_next_operation_id()
93+
p2 = async_utils.start_async_mkdir(tmpdir2)
94+
await p1.await_creation()
95+
await p2.await_creation()
96+
# Awaiting sequentially does not take any longer than awaiting in parallel,
97+
# because the operations are already in progress.
98+
self.assertBetween(1, time.time() - start, 2)
99+
await self.assertExists(self.directory / 'tmp1')
100+
await self.assertExists(self.directory / 'tmp2')
101+
102+
async def test_async_mkdir_parallel(self):
103+
tmpdir1 = atomicity.AtomicRenameTemporaryPath(
104+
self.directory / 'tmp1', self.directory / 'final1'
105+
)
106+
tmpdir2 = atomicity.AtomicRenameTemporaryPath(
107+
self.directory / 'tmp2', self.directory / 'final2'
108+
)
109+
start = time.time()
110+
await context_lib.synchronize_next_operation_id()
111+
p1 = async_utils.start_async_mkdir(tmpdir1)
112+
await context_lib.synchronize_next_operation_id()
113+
p2 = async_utils.start_async_mkdir(tmpdir2)
114+
await asyncio.gather(p1.await_creation(), p2.await_creation())
115+
self.assertBetween(1, time.time() - start, 2)
116+
await self.assertExists(self.directory / 'tmp1')
117+
await self.assertExists(self.directory / 'tmp2')
118+
119+
async def test_async_mkdir_with_delayed_wait(self):
120+
await context_lib.synchronize_next_operation_id()
121+
tmpdir = atomicity.AtomicRenameTemporaryPath(
122+
self.directory / 'tmp', self.directory / 'final'
123+
)
124+
p = async_utils.start_async_mkdir(tmpdir)
125+
await self.assertNotExists(self.directory / 'tmp')
126+
await asyncio.sleep(1)
127+
await p.await_creation()
128+
await self.assertExists(self.directory / 'tmp')
129+
130+
131+
if __name__ == '__main__':
132+
absltest.main()

0 commit comments

Comments
 (0)