Skip to content

Commit 7da0ece

Browse files
author
Orbax Authors
committed
No public description
PiperOrigin-RevId: 842606313
1 parent c722817 commit 7da0ece

File tree

1 file changed

+5
-13
lines changed
  • checkpoint/orbax/checkpoint/_src/path

1 file changed

+5
-13
lines changed

checkpoint/orbax/checkpoint/_src/path/step.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -69,15 +69,6 @@
6969
]
7070

7171

72-
def _validate_base_path(base_path: epath.PathLike):
73-
"""Validates base_path and returns it as an epath.Path."""
74-
base_path = epath.Path(base_path)
75-
if not base_path.exists():
76-
raise ValueError(f'Base path {base_path} does not exist.')
77-
if not base_path.is_dir():
78-
raise ValueError(f'Base path {base_path} is not a directory.')
79-
80-
8172
@dataclasses.dataclass(frozen=True)
8273
class Metadata:
8374
"""Metadata of a step.
@@ -405,8 +396,11 @@ def _glob_step_paths(self, base_path: epath.PathLike) -> list[epath.Path]:
405396
if folder.startswith(os.path.join(path_prefix, self.step_prefix))
406397
]
407398
else:
408-
prefix = step_prefix_with_underscore(self.step_prefix)
409-
return [x for x in base_path.iterdir() if x.name.startswith(prefix)]
399+
return list(
400+
epath.Path(base_path).glob(
401+
f'{step_prefix_with_underscore(self.step_prefix)}*'
402+
)
403+
)
410404

411405
def _get_step_paths_and_total_steps(
412406
self, base_path: epath.PathLike, is_primary_host: bool
@@ -532,7 +526,6 @@ def _find_all_with_single_host_load_and_broadcast(
532526

533527
def find_all(self, base_path: epath.PathLike) -> Iterator[Metadata]:
534528
"""Returns metadata of all steps matching with name_format attributes."""
535-
_validate_base_path(base_path)
536529
# Note: the order of conjuncts is important here; we should not call
537530
# `multihost.process_count()` when `single_host_load_and_broadcast` is False
538531
# as this has the possible side effect of initializing the jax backend. See
@@ -546,7 +539,6 @@ def find_all(self, base_path: epath.PathLike) -> Iterator[Metadata]:
546539

547540
def find_step(self, base_path: epath.PathLike, step: int) -> Metadata:
548541
"""Returns the metadata for `step` or raises ValueError."""
549-
_validate_base_path(base_path)
550542
step_path = build_step_path(base_path, self, step)
551543
metadata = self._build_metadata(step_path, step=step)
552544
if metadata is not None:

0 commit comments

Comments
 (0)