Skip to content

Commit 9136aa4

Browse files
mridul-sahuOrbax Authors
authored andcommitted
Use etils iterdir and in-memory filtering instead of glob in _glob_step_paths
PiperOrigin-RevId: 842523887
1 parent 0986794 commit 9136aa4

File tree

1 file changed

+13
-5
lines changed
  • checkpoint/orbax/checkpoint/_src/path

1 file changed

+13
-5
lines changed

checkpoint/orbax/checkpoint/_src/path/step.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,15 @@
6969
]
7070

7171

72+
def _validate_base_path(base_path: epath.PathLike):
73+
"""Validates base_path and returns it as an epath.Path."""
74+
base_path = epath.Path(base_path)
75+
if not base_path.exists():
76+
raise ValueError(f'Base path {base_path} does not exist.')
77+
if not base_path.is_dir():
78+
raise ValueError(f'Base path {base_path} is not a directory.')
79+
80+
7281
@dataclasses.dataclass(frozen=True)
7382
class Metadata:
7483
"""Metadata of a step.
@@ -396,11 +405,8 @@ def _glob_step_paths(self, base_path: epath.PathLike) -> list[epath.Path]:
396405
if folder.startswith(os.path.join(path_prefix, self.step_prefix))
397406
]
398407
else:
399-
return list(
400-
epath.Path(base_path).glob(
401-
f'{step_prefix_with_underscore(self.step_prefix)}*'
402-
)
403-
)
408+
prefix = step_prefix_with_underscore(self.step_prefix)
409+
return [x for x in base_path.iterdir() if x.name.startswith(prefix)]
404410

405411
def _get_step_paths_and_total_steps(
406412
self, base_path: epath.PathLike, is_primary_host: bool
@@ -526,6 +532,7 @@ def _find_all_with_single_host_load_and_broadcast(
526532

527533
def find_all(self, base_path: epath.PathLike) -> Iterator[Metadata]:
528534
"""Returns metadata of all steps matching with name_format attributes."""
535+
_validate_base_path(base_path)
529536
# Note: the order of conjuncts is important here; we should not call
530537
# `multihost.process_count()` when `single_host_load_and_broadcast` is False
531538
# as this has the possible side effect of initializing the jax backend. See
@@ -539,6 +546,7 @@ def find_all(self, base_path: epath.PathLike) -> Iterator[Metadata]:
539546

540547
def find_step(self, base_path: epath.PathLike, step: int) -> Metadata:
541548
"""Returns the metadata for `step` or raises ValueError."""
549+
_validate_base_path(base_path)
542550
step_path = build_step_path(base_path, self, step)
543551
metadata = self._build_metadata(step_path, step=step)
544552
if metadata is not None:

0 commit comments

Comments
 (0)