Skip to content

Commit c722817

Browse files
mridul-sahu and Orbax Authors
authored and committed
Move temporary path cleanup to a separate thread so that it does not block checkpoint manager init. The cleanup is now awaited before the first save after init.
PiperOrigin-RevId: 842605753
1 parent a13d078 commit c722817

File tree

3 files changed

+13
-29
lines changed

3 files changed

+13
-29
lines changed

checkpoint/orbax/checkpoint/_src/path/temporary_paths.py

Lines changed: 0 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -156,11 +156,3 @@ async def cleanup_temporary_paths(
156156
await asyncio.gather(
157157
*[async_path.rmtree(tmp_path.get()) for tmp_path in tmp_paths]
158158
)
159-
multihost.sync_global_processes(
160-
multihost.unique_barrier_key(
161-
'cleanup_tmp_dirs',
162-
prefix=multiprocessing_options.barrier_sync_key_prefix,
163-
),
164-
timeout=multihost.coordination_timeout(),
165-
processes=multiprocessing_options.active_processes,
166-
)

checkpoint/orbax/checkpoint/checkpoint_manager.py

Lines changed: 12 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,6 @@
3434
from orbax.checkpoint import checkpoint_args
3535
from orbax.checkpoint import options as options_lib
3636
from orbax.checkpoint import utils
37-
from orbax.checkpoint._src import asyncio_utils
3837
from orbax.checkpoint._src import threading as threading_lib
3938
from orbax.checkpoint._src.checkpoint_managers import policy_checkpoint_info
4039
from orbax.checkpoint._src.checkpoint_managers import preservation_policy as preservation_policy_lib
@@ -858,8 +857,9 @@ def __init__(
858857

859858

860859
# Cleanup directories from previous runs that may not have been finalized.
860+
self._cleanup_tmp_directory_future = None
861861
if self._options.cleanup_tmp_directories:
862-
asyncio_utils.run_sync(
862+
self._cleanup_tmp_directory_future = future.CommitFuture(
863863
temporary_paths.cleanup_temporary_paths(
864864
self._directory,
865865
multiprocessing_options=self._options.multiprocessing_options,
@@ -928,6 +928,14 @@ def __init__(
928928
self,
929929
)
930930

931+
def _maybe_await_cleanup_tmp_directory(self):
932+
if self._cleanup_tmp_directory_future is None:
933+
return
934+
935+
self._cleanup_tmp_directory_future.result()
936+
# Reset the future to None to avoid waiting for cleanup again.
937+
self._cleanup_tmp_directory_future = None
938+
931939
def _configure_checkpointer_common(
932940
self,
933941
handler: CompositeCheckpointHandler,
@@ -1388,7 +1396,8 @@ def save(
13881396
self._validate_args(items, args)
13891397
if not force and not self.should_save(step):
13901398
return False
1391-
1399+
# Wait for any ongoing temporary path cleanup before starting the save.
1400+
self._maybe_await_cleanup_tmp_directory()
13921401
multihost.sync_global_processes(
13931402
multihost.unique_barrier_key(
13941403
'CheckpointManager:save_start',

checkpoint/orbax/checkpoint/checkpoint_manager_test.py

Lines changed: 1 addition & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -928,20 +928,6 @@ def test_save_preempted(self, enable_async, prefix):
928928
cleanup_tmp_directories=True,
929929
),
930930
) as manager:
931-
# Temp checkpoints cleaned up at creation.
932-
self.assertFalse(tmp_dir.exists())
933-
self.assertSameElements([], manager.all_steps())
934-
935-
# Sync to check directories before a new tmp dir is created.
936-
test_utils.sync_global_processes('test_check_dirs_after_cleanup')
937-
938-
tmp_dir = test_utils.save_fake_tmp_dir(
939-
self.directory, 0, 'params', subdirs=['subdir'], step_prefix=prefix
940-
)
941-
self.assertTrue(tmp_dir.exists())
942-
tmp_dir_items = list(tmp_dir.iterdir())
943-
self.assertLen(tmp_dir_items, 1)
944-
self.assertIn('subdir', tmp_dir_items[0].name)
945931
self.assertSameElements(
946932
[], manager.all_steps()
947933
) # Only picks up finalized.
@@ -1130,11 +1116,8 @@ def test_save_preempted_gcs(self, enable_async, is_gcs_path):
11301116
enable_async_checkpointing=enable_async,
11311117
),
11321118
) as manager:
1133-
# Temp checkpoints cleaned up at creation.
1134-
self.assertFalse(tmp_dir.exists())
1135-
self.assertFalse((tmp_dir / 'subdir').exists())
11361119
self.assertSameElements([], manager.all_steps())
1137-
1120+
manager._maybe_await_cleanup_tmp_directory()
11381121
# Sync to check directories before a new tmp dir is created.
11391122
test_utils.sync_global_processes('test_check_dirs_after_cleanup')
11401123

0 commit comments

Comments
 (0)