@@ -364,6 +364,12 @@ class CheckpointManagerOptions:
364364 supposed to be created per process. This is used to support async
365365 directory creation. If True, `multiprocessing_options.primary_host` must be
366366 None.
367+ lightweight_initialize: If True, checkpoint step metadata is not
368+ read on
369+ CheckpointManager initialization during checkpoint info loading. This is
370+ useful to improve init performance
371+ when there are O(1k) or more existing checkpoint step present and checkpoint
372+ info properties like `time` and `metrics` are not needed.
367373 """
368374
369375 save_interval_steps : int = 1
@@ -404,6 +410,7 @@ class CheckpointManagerOptions:
404410 # TODO(b/428061876) Remove this option.
405411 enable_should_save_is_saving_in_progress_check : bool = True
406412 enable_per_process_directory_creation : bool = False
413+ lightweight_initialize : bool = False
407414
408415 def __post_init__ (self ):
409416 step_name_format_single_host_load_and_broadcast = (
@@ -879,7 +886,9 @@ def __init__(
879886 )
880887
881888 self ._checkpoints = checkpoint_info .CheckpointInfos (
882- self ._load_checkpoint_infos ()
889+ self ._load_checkpoint_infos (
890+ skip_metadata_read = self ._options .lightweight_initialize
891+ )
883892 )
884893
885894 self ._metadata_dir = self .directory / METADATA_ITEM_NAME
@@ -1764,11 +1773,17 @@ def _track_best(self):
17641773 """Returns true if we should track the best checkpoints by given metric."""
17651774 return self ._options .best_fn is not None
17661775
1767- def _load_checkpoint_infos (self ) -> List [CheckpointInfo ]:
1776+ def _load_checkpoint_infos (
1777+ self , skip_metadata_read = False
1778+ ) -> List [CheckpointInfo ]:
17681779 """Loads a list of CheckpointInfo for existing checkpoints.
17691780
17701781 If none are present, returns empty list.
17711782
1783+ Args:
1784+ skip_metadata_read: If True, will not read metadata from disk to build
1785+ checkpoint infos.
1786+
17721787 Returns:
17731788 a list of CheckpointInfo, sorted by increasing step.
17741789 """
@@ -1778,11 +1793,18 @@ def _load_checkpoint_infos(self) -> List[CheckpointInfo]:
17781793 step_metadatas = self ._step_name_format .find_all (self .directory )
17791794
17801795 def build_checkpoint_info (step_metadata ):
1781- return CheckpointInfo (
1782- step = step_metadata .step ,
1783- time = step_metadata .commit_timestamp ,
1784- metrics = self .metrics (step_metadata .step ),
1785- )
1796+ if skip_metadata_read :
1797+ return CheckpointInfo (
1798+ step = step_metadata .step ,
1799+ time = datetime .datetime .min ,
1800+ metrics = None ,
1801+ )
1802+ else :
1803+ return CheckpointInfo (
1804+ step = step_metadata .step ,
1805+ time = step_metadata .commit_timestamp ,
1806+ metrics = self .metrics (step_metadata .step ),
1807+ )
17861808
17871809 with concurrent .futures .ThreadPoolExecutor () as executor :
17881810 checkpoint_infos = list (
0 commit comments