Skip to content

Commit bb5a0f4

Browse files
mridul-sahu authored and Orbax Authors committed
Add lightweight_initialize to skip reading metadata on checkpoint info load on init.
PiperOrigin-RevId: 842189498
1 parent 07ff391 commit bb5a0f4

File tree

2 files changed

+80
-7
lines changed

2 files changed

+80
-7
lines changed

checkpoint/orbax/checkpoint/checkpoint_manager.py

Lines changed: 29 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -364,6 +364,12 @@ class CheckpointManagerOptions:
364364
supposed to be created per process. This is used to support async
365365
directory creation. If True, `multiprocessing_options.primary_host` must be
366366
None.
367+
lightweight_initialize: If True, checkpoint step metadata is not
368+
read on
369+
CheckpointManager initialization during checkpoint info loading. This is
370+
useful to improve init performance
371+
when there are O(1k) or more existing checkpoint steps present and checkpoint
372+
info properties like `time` and `metrics` are not needed.
367373
"""
368374

369375
save_interval_steps: int = 1
@@ -404,6 +410,7 @@ class CheckpointManagerOptions:
404410
# TODO(b/428061876) Remove this option.
405411
enable_should_save_is_saving_in_progress_check: bool = True
406412
enable_per_process_directory_creation: bool = False
413+
lightweight_initialize: bool = False
407414

408415
def __post_init__(self):
409416
step_name_format_single_host_load_and_broadcast = (
@@ -879,7 +886,9 @@ def __init__(
879886
)
880887

881888
self._checkpoints = checkpoint_info.CheckpointInfos(
882-
self._load_checkpoint_infos()
889+
self._load_checkpoint_infos(
890+
skip_metadata_read=self._options.lightweight_initialize
891+
)
883892
)
884893

885894
self._metadata_dir = self.directory / METADATA_ITEM_NAME
@@ -1764,11 +1773,17 @@ def _track_best(self):
17641773
"""Returns true if we should track the best checkpoints by given metric."""
17651774
return self._options.best_fn is not None
17661775

1767-
def _load_checkpoint_infos(self) -> List[CheckpointInfo]:
1776+
def _load_checkpoint_infos(
1777+
self, skip_metadata_read=False
1778+
) -> List[CheckpointInfo]:
17681779
"""Loads a list of CheckpointInfo for existing checkpoints.
17691780
17701781
If none are present, returns empty list.
17711782
1783+
Args:
1784+
skip_metadata_read: If True, will not read metadata from disk to build
1785+
checkpoint infos.
1786+
17721787
Returns:
17731788
a list of CheckpointInfo, sorted by increasing step.
17741789
"""
@@ -1778,11 +1793,18 @@ def _load_checkpoint_infos(self) -> List[CheckpointInfo]:
17781793
step_metadatas = self._step_name_format.find_all(self.directory)
17791794

17801795
def build_checkpoint_info(step_metadata):
1781-
return CheckpointInfo(
1782-
step=step_metadata.step,
1783-
time=step_metadata.commit_timestamp,
1784-
metrics=self.metrics(step_metadata.step),
1785-
)
1796+
if skip_metadata_read:
1797+
return CheckpointInfo(
1798+
step=step_metadata.step,
1799+
time=datetime.datetime.min,
1800+
metrics=None,
1801+
)
1802+
else:
1803+
return CheckpointInfo(
1804+
step=step_metadata.step,
1805+
time=step_metadata.commit_timestamp,
1806+
metrics=self.metrics(step_metadata.step),
1807+
)
17861808

17871809
with concurrent.futures.ThreadPoolExecutor() as executor:
17881810
checkpoint_infos = list(

checkpoint/orbax/checkpoint/checkpoint_manager_test.py

Lines changed: 51 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -530,6 +530,57 @@ def test_all_steps_reload(self):
530530
manager.close()
531531
new_manager.close()
532532

533+
def test_lightweight_intialize_loads_without_metadata_read(self):
534+
options = CheckpointManagerOptions(
535+
best_fn=lambda metrics: metrics['loss'],
536+
best_mode='min',
537+
)
538+
with CheckpointManager(
539+
self.directory, item_names=('params',), options=options
540+
) as manager:
541+
self.assertTrue(
542+
manager.save(
543+
0,
544+
args=args.Composite(params=args.PyTreeSave(self.pytree)),
545+
metrics={'loss': 1.0},
546+
)
547+
)
548+
self.assertTrue(
549+
manager.save(
550+
1,
551+
args=args.Composite(params=args.PyTreeSave(self.pytree)),
552+
metrics={'loss': 0.5},
553+
)
554+
)
555+
self.wait_if_async(manager)
556+
557+
light_options = CheckpointManagerOptions(
558+
lightweight_initialize=True,
559+
best_fn=lambda metrics: metrics['loss'],
560+
best_mode='min',
561+
)
562+
with CheckpointManager(
563+
self.directory, item_names=('params',), options=light_options
564+
) as light_manager:
565+
self.assertLen(light_manager._checkpoints, 2)
566+
for ckpt in light_manager._checkpoints:
567+
self.assertIsNone(ckpt.metrics)
568+
self.assertEqual(ckpt.time, datetime.datetime.min)
569+
570+
# reload should load metrics
571+
light_manager.reload()
572+
self.assertLen(light_manager._checkpoints, 2)
573+
self.assertIsNotNone(light_manager._checkpoints[0].metrics)
574+
self.assertIsNotNone(light_manager._checkpoints[1].metrics)
575+
self.assertEqual(light_manager._checkpoints[0].metrics['loss'], 1.0)
576+
self.assertEqual(light_manager._checkpoints[1].metrics['loss'], 0.5)
577+
self.assertNotEqual(
578+
light_manager._checkpoints[0].time, datetime.datetime.min
579+
)
580+
self.assertNotEqual(
581+
light_manager._checkpoints[1].time, datetime.datetime.min
582+
)
583+
533584
@parameterized.parameters((False, 1), (True, 2))
534585
def test_latest_step(self, enable_async, save_interval_steps):
535586
options = CheckpointManagerOptions(

0 commit comments

Comments
 (0)