Skip to content

Commit 07ff391

Browse files
mridul-sahuOrbax Authors
authored andcommitted
No public description
PiperOrigin-RevId: 842747698
1 parent 3c1c4c9 commit 07ff391

File tree

1 file changed

+15
-5
lines changed
  • checkpoint/orbax/checkpoint/_src/path

1 file changed

+15
-5
lines changed

checkpoint/orbax/checkpoint/_src/path/step.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,12 @@
6969
]
7070

7171

72+
def _is_valid_base_path(base_path: epath.PathLike) -> bool:
73+
"""Validates base_path and returns it as an epath.Path."""
74+
base_path = epath.Path(base_path)
75+
return base_path.exists() and base_path.is_dir()
76+
77+
7278
@dataclasses.dataclass(frozen=True)
7379
class Metadata:
7480
"""Metadata of a step.
@@ -396,11 +402,8 @@ def _glob_step_paths(self, base_path: epath.PathLike) -> list[epath.Path]:
396402
if folder.startswith(os.path.join(path_prefix, self.step_prefix))
397403
]
398404
else:
399-
return list(
400-
epath.Path(base_path).glob(
401-
f'{step_prefix_with_underscore(self.step_prefix)}*'
402-
)
403-
)
405+
prefix = step_prefix_with_underscore(self.step_prefix)
406+
return [x for x in base_path.iterdir() if x.name.startswith(prefix)]
404407

405408
def _get_step_paths_and_total_steps(
406409
self, base_path: epath.PathLike, is_primary_host: bool
@@ -526,6 +529,8 @@ def _find_all_with_single_host_load_and_broadcast(
526529

527530
def find_all(self, base_path: epath.PathLike) -> Iterator[Metadata]:
528531
"""Returns metadata of all steps matching with name_format attributes."""
532+
if not _is_valid_base_path(base_path):
533+
return iter([])
529534
# Note: the order of conjuncts is important here; we should not call
530535
# `multihost.process_count()` when `single_host_load_and_broadcast` is False
531536
# as this has the possible side effect of initializing the jax backend. See
@@ -539,6 +544,11 @@ def find_all(self, base_path: epath.PathLike) -> Iterator[Metadata]:
539544

540545
def find_step(self, base_path: epath.PathLike, step: int) -> Metadata:
541546
"""Returns the metadata for `step` or raises ValueError."""
547+
if not _is_valid_base_path(base_path):
548+
raise ValueError(
549+
f'Invalid base_path: {base_path} does not exist or is not a'
550+
' directory.'
551+
)
542552
step_path = build_step_path(base_path, self, step)
543553
metadata = self._build_metadata(step_path, step=step)
544554
if metadata is not None:

0 commit comments

Comments
 (0)