Skip to content

Commit a4ece16

Browse files
ChromeHeartsOrbax Authors
authored andcommitted
Improve V1 docstrings
PiperOrigin-RevId: 839968361
1 parent b3bb2f5 commit a4ece16

39 files changed

+445
-334
lines changed

checkpoint/orbax/checkpoint/experimental/v1/_src/context/context.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ class Context(epy.ContextManager):
4747
with ocp.Context(...):
4848
ocp.save_pytree(...)
4949
50-
Creating a new `Context` within an existing `Context` sets all parameters from
51-
scratch; it does not inherit properties from the parent `Context`. To achieve
52-
this, use::
50+
Creating a new :py:class:`.Context` within an existing :py:class:`.Context`
51+
sets all parameters from scratch; it does not inherit properties from the
52+
parent :py:class:`.Context`. To achieve this, use::
5353
5454
with Context(**some_properties) as outer_ctx:
5555
with Context(outer_ctx, **other) as inner_ctx:
@@ -59,7 +59,7 @@ class Context(epy.ContextManager):
5959
properties modified in the `dataclasses.replace` call.
6060
6161
NOTE: The context is not shared across threads. In other words, the whole
62-
context block must be executed in the same thread. Following example will
62+
context block must be executed in the same thread. The following example will
6363
not work as expected::
6464
6565
executor = ThreadPoolExecutor()

checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class AsyncOptions:
4343
post_finalization_callback:
4444
A function that is called after the async save operation is complete.
4545
create_directories_asynchronously:
46-
If true, create directories asynchronously in the background.
46+
If True, creates directories asynchronously in the background.
4747
"""
4848

4949
timeout_secs: int = 600 # 10 minutes.
@@ -67,15 +67,16 @@ class MultiprocessingOptions:
6767
all hosts will be considered as primary. It's useful in the case that all
6868
hosts are only working with local storage.
6969
active_processes:
70-
A set of process indices (corresponding to `multihost.process_index()`) over
71-
which `CheckpointManager` is expected to be called. This makes it possible
72-
to have a `CheckpointManager` instance that runs over a subset of processes,
73-
rather than all processes as it is normally expected to do. If specified,
74-
`primary_host` must belong to `active_processes`.
70+
A set of process indices (corresponding to :py:func:`.process_index`) over
71+
which :py:class:`~.v1.training.Checkpointer` is expected to be called.
72+
This makes it possible to have a :py:class:`~.v1.training.Checkpointer`
73+
instance that runs over a subset of processes, rather than all processes as
74+
it is normally expected to do. If specified, `primary_host` must belong to
75+
`active_processes`.
7576
barrier_sync_key_prefix:
7677
A string to be prepended to the barrier sync key used to synchronize
7778
processes. This is useful to avoid collisions with other barrier syncs if
78-
another CheckpointManager is being used concurrently.
79+
another :py:class:`~.v1.training.Checkpointer` is being used concurrently.
7980
"""
8081

8182
primary_host: int | None = 0
@@ -102,20 +103,20 @@ class FileOptions:
102103
https://github.com/google/etils/blob/main/etils/epath/backend.py if your
103104
path is supported. default=None.
104105
temporary_path_class:
105-
A class that is used to create and finallize temporary paths, and to
106-
ensure atomicity.
106+
A class that is used to create and finalize temporary paths, and to ensure
107+
atomicity.
107108
path_class:
108-
The implementation of `path_types.Path` to use. Defaults to
109-
`etils.epath.Path`, but may be overridden to some other subclass of
110-
`path_types.Path`.
109+
The implementation of :py:class:`~.v1.path.Path` to use. Defaults to
110+
`etils.epath.Path`, but may be overridden to some other subclass of
111+
:py:class:`~.v1.path.Path`.
111112
"""
112113

113114
path_permission_mode: int | None = None
114115
temporary_path_class: type[atomicity_types.TemporaryPath] | None = None
115116
path_class: type[path_types.Path] = epath.Path
116117

117118
def v0(self) -> v0_options_lib.FileOptions:
118-
"""Converts this FileOptions to a v0 FileOptions."""
119+
"""Converts this :py:class:`~.v1.options.FileOptions` to a v0 :py:class:`~orbax.checkpoint.options.FileOptions`."""
119120
return v0_options_lib.FileOptions(
120121
path_permission_mode=self.path_permission_mode,
121122
)
@@ -141,11 +142,12 @@ class Saving:
141142
142143
create_array_storage_options_fn:
143144
A function that is called in order to create
144-
`ArrayOptions.Saving.StorageOptions` for each leaf in a PyTree, when it is
145+
:py:class:`.ArrayOptions.Saving.StorageOptions` for each leaf in a PyTree,
146+
when it is
145147
being saved. It is called similar to:
146148
`jax.tree.map_with_path(create_array_storage_options_fn, pytree_to_save)`.
147149
If provided, it overrides any default settings in
148-
`ArrayOptions.Saving.StorageOptions`.
150+
:py:class:`.ArrayOptions.Saving.StorageOptions`.
149151
pytree_metadata_options: Options for managing PyTree metadata.
150152
"""
151153

@@ -230,19 +232,19 @@ class StorageOptions:
230232
231233
dtype:
232234
If provided, casts the parameter to the given dtype before saving.
233-
Note that the parameter must be compatible with the given type (e.g.
234-
jnp.bfloat16 is not compatible with np.ndarray).
235+
Note that the parameter must be compatible with the given type (e.g.,
236+
`jnp.bfloat16` is not compatible with `np.ndarray`).
235237
chunk_byte_size:
236238
This is an experimental feature that automatically chooses the largest
237-
chunk shape possible, while keeping the chunk byte size less than or
238-
equal to the specified chunk_byte_size. Both the write_chunk_shape and
239-
read_chunk_shape are automatically set to the chosen shape. This uses a
240-
greedy algorithm that prioritizes splitting the largest dimensions
239+
possible chunk shape while keeping the chunk byte size less than or
240+
equal to the specified `chunk_byte_size`. Both `write_chunk_shape` and
241+
`read_chunk_shape` are automatically set to the chosen shape. This uses
242+
a greedy algorithm that prioritizes splitting the largest dimensions
241243
first.
242244
shard_axes:
243-
An optional list of axes that should be prioritized when sharding array
244-
for storage. If empty, storage sharding implementation will prioritize
245-
axes which are already sharded.
245+
An optional list of axes that should be prioritized when sharding an
246+
array for storage. If empty, the storage sharding implementation will
247+
prioritize axes which are already sharded.
246248
"""
247249

248250
dtype: np.typing.DTypeLike | None = None
@@ -322,9 +324,9 @@ class CheckpointablesOptions:
322324
first because it is registered first.
323325
324326
Attributes:
325-
registry: A `CheckpointableHandlerRegistry` that is used to resolve
326-
`CheckpointableHandler` classes for each provided `checkpointable` during
327-
saving and loading.
327+
registry: A :py:class:`.CheckpointableHandlerRegistry` that is used to
328+
resolve :py:class:`.CheckpointableHandler` classes for each provided
329+
`checkpointable` during saving and loading.
328330
"""
329331

330332
registry: registration.CheckpointableHandlerRegistry = dataclasses.field(

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/compatibility.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232

3333
class _PathAwaitingCreation(path_types.PathAwaitingCreation):
34-
"""Implementation of `PathAwaitingCreation` that awaits contracted signals."""
34+
"""Implementation of :py:class:`~.v1.path.PathAwaitingCreation` that awaits contracted signals."""
3535

3636
def __init__(self, path: path_types.Path, operation_id: str):
3737
self._path = path
@@ -56,7 +56,7 @@ def path(self) -> path_types.Path:
5656
class CompatibilityCheckpointHandler(
5757
async_checkpoint_handler.AsyncCheckpointHandler
5858
):
59-
"""A V0 CheckpointHandler that wraps a V1 CheckpointableHandler."""
59+
"""A V0 :py:class:`~orbax.checkpoint.AsyncCheckpointHandler` that wraps a V1 :py:class:`~.v1.handlers.CheckpointableHandler`."""
6060

6161
def __init__(self, handler: handler_types.CheckpointableHandler):
6262
self._handler = handler

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/composite_handler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Defines `CompositeHandler`, a helper component for saving and loading."""
15+
"""Defines CompositeHandler, a helper component for saving and loading."""
1616

1717
from __future__ import annotations
1818

@@ -74,8 +74,8 @@ async def _create_orbax_identifier_file(
7474
class CompositeHandler:
7575
"""CompositeHandler.
7676
77-
This class is a helper component for `save_checkpointables` and
78-
`load_checkpointables`. It performs a few core functions:
77+
This class is a helper component for :py:func:`~.v1.save_checkpointables`,
78+
:py:func:`~.v1.load_checkpointables`, etc. It performs a few core functions:
7979
- Resolves handlers for saving and loading.
8080
- Saves and loads checkpointables to/from individual subdirectories by
8181
delegating to the resolved handlers.

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/json_handler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Implementation of `CheckpointableHandler` for PyTrees."""
15+
"""Implementation of :py:class:`.CheckpointableHandler` for PyTrees."""
1616

1717
from __future__ import annotations
1818

@@ -40,7 +40,7 @@ def _get_supported_filenames(filename: str | None = None) -> list[str]:
4040

4141

4242
class JsonHandler(CheckpointableHandler[JsonType, None]):
43-
"""An implementation of `CheckpointableHandler` for Json."""
43+
"""An implementation of :py:class:`.CheckpointableHandler` for Json."""
4444

4545
def __init__(self, filename: str | None = None):
4646
self._supported_filenames = _get_supported_filenames(filename)

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/proto_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __init__(
3838
self,
3939
filename: str = _DEFAULT_FILENAME,
4040
):
41-
"""Initializes ProtoCheckpointHandler."""
41+
"""Initializes :py:class:`.ProtoHandler`."""
4242
self._filename = filename
4343

4444
async def _background_save(

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Implementation of :py:class:`.CheckpointableHandler` for PyTrees."""
15+
"""Implementation of :py:class:`~.v1.handlers.CheckpointableHandler` for PyTrees."""
1616

1717
from __future__ import annotations
1818

@@ -123,24 +123,24 @@ def create_v0_save_args(
123123
def _restore_type_by_abstract_type(
124124
abstract_checkpointable: Any,
125125
) -> Any:
126-
"""This is to allow users to override the restored type.
126+
"""Allows users to override the restored type.
127127
128-
When users pass in the `value` in the DeserializationParam, the PytreeHandler
129-
will try to restore to the specified type. T. This only supports the standard
128+
When users pass the `value` in the `DeserializationParam`, the `PyTreeHandler`
129+
will try to restore to the specified type `T`. This only supports the standard
130130
types supported by Orbax.
131131
For example:
132-
- jax.ShapeDtype -> jax.Array
133-
- NumpyAbstractType -> jax.Array
134-
- int | float | Type[int] | Type[float] -> int | float | int | float
132+
- `jax.ShapeDtype` -> `jax.Array`
133+
- `NumpyAbstractType` -> `jax.Array`
134+
- `int` | `float` | `Type[int]` | `Type[float]` -> `int` | `float` | `int` |
135+
`float`
135136
136137
Args:
137-
abstract_checkpointable: The abstract checkpointable that passed in by the
138-
user.
138+
abstract_checkpointable: The abstract checkpointable passed in by the user.
139139
140140
Returns:
141-
Return the restore_type parameter for the V0RestoreArgs. This is needed to
142-
determine which LeafHandler will eventually handle this
143-
abstract_checkpointable.
141+
Returns the `restore_type` parameter for `V0RestoreArgs`. This is needed to
142+
determine which `LeafHandler` will eventually handle this
143+
`abstract_checkpointable`.
144144
"""
145145

146146
if abstract_checkpointable is None:
@@ -315,8 +315,9 @@ async def load(
315315
abstract_checkpointable: The abstract checkpointable to load into. If
316316
None, the handler will attempt to load the entire checkpoint using the
317317
recorded metadata. Otherwise, the `abstract_checkpointable` is expected
318-
to be a PyTree of abstract leaves. See :py:class:`.LeafHandler` for more
319-
details. The abstract leaf may be a value of type `AbstractLeaf`,
318+
to be a PyTree of abstract leaves. See
319+
:py:class:`~.v1.serialization.LeafHandler` for more details. The
320+
abstract leaf may be a value of type `AbstractLeaf`,
320321
`Type[AbstractLeaf]`, or `None`. E.g. if the `AbstractLeaf` is
321322
`AbstractFoo`, it is always valid to pass `AbstractFoo()` or
322323
`AbstractFoo` or `None`. Passing the latter two indicates that metadata

0 commit comments

Comments
 (0)