Skip to content

Commit 19bbd9b

Browse files
Brax Teambtaba
authored andcommitted
Internal change
PiperOrigin-RevId: 747504144 Change-Id: Ia6f000e46f84e5f62793dbf7ac347a5ca8d8f540
1 parent c74477f commit 19bbd9b

File tree

6 files changed

+36
-14
lines changed

6 files changed

+36
-14
lines changed

MANIFEST.in

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
prune docs
2+
prune notebooks
3+
prune datasets
4+
15
include brax/envs/assets/*.xml
26
recursive-include brax/experimental/barkour *.csv *.stl *.xml
37
recursive-include brax/test_data *.xml *.stl *.obj *.urdf

brax/training/agents/ppo/train.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -423,13 +423,6 @@ def train(
423423
progress_fn=progress_fn,
424424
)
425425

426-
ckpt_config = checkpoint.network_config(
427-
observation_size=obs_shape,
428-
action_size=env.action_size,
429-
normalize_observations=normalize_observations,
430-
network_factory=network_factory,
431-
)
432-
433426
def minibatch_step(
434427
carry,
435428
data: types.Transition,
@@ -713,6 +706,12 @@ def training_epoch_with_timing(
713706
policy_params_fn(current_step, make_policy, params)
714707

715708
if save_checkpoint_path is not None:
709+
ckpt_config = checkpoint.network_config(
710+
observation_size=obs_shape,
711+
action_size=env.action_size,
712+
normalize_observations=normalize_observations,
713+
network_factory=network_factory,
714+
)
716715
checkpoint.save(
717716
save_checkpoint_path, current_step, params, ckpt_config
718717
)

brax/training/agents/sac/train.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -260,13 +260,6 @@ def train(
260260
actor_loss, policy_optimizer, pmap_axis_name=_PMAP_AXIS_NAME
261261
)
262262

263-
ckpt_config = checkpoint.network_config(
264-
observation_size=obs_size,
265-
action_size=env.action_size,
266-
normalize_observations=normalize_observations,
267-
network_factory=network_factory,
268-
)
269-
270263
def sgd_step(
271264
carry: Tuple[TrainingState, PRNGKey], transitions: Transition
272265
) -> Tuple[Tuple[TrainingState, PRNGKey], Metrics]:
@@ -585,6 +578,12 @@ def training_epoch_with_timing(
585578
params = _unpmap(
586579
(training_state.normalizer_params, training_state.policy_params)
587580
)
581+
ckpt_config = checkpoint.network_config(
582+
observation_size=obs_size,
583+
action_size=env.action_size,
584+
normalize_observations=normalize_observations,
585+
network_factory=network_factory,
586+
)
588587
checkpoint.save(checkpoint_logdir, current_step, params, ckpt_config)
589588

590589
# Run evals.

brax/training/checkpoint.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def network_config(
7171
)
7272
del kwargs['preprocess_observations_fn']
7373
if 'activation' in kwargs:
74+
# TODO: Add other activations.
7475
if kwargs['activation'] != defaults['activation']:
7576
raise ValueError('checkpointing only supports default activation')
7677
del kwargs['activation']

docs/release-notes/next-release.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
11
# Brax Release Notes
2+
3+
* Fix #595, patch for checkpointing with activations other than relu when no checkpoint path is specified.

pyproject.toml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,23 @@ Homepage = "http://github.com/google/brax"
7171

7272
[tool.hatch.build.targets.wheel]
7373
packages = ["brax"]
74+
exclude = [
75+
"/datasets",
76+
"/docs",
77+
"/notebooks",
78+
"/tests",
79+
"brax/experimental/barkour/assets",
80+
"brax/experimental/barkour/data",
81+
]
82+
83+
[tool.hatch.build.targets.sdist]
84+
exclude = [
85+
"/datasets",
86+
"/docs",
87+
"/notebooks",
88+
"brax/experimental/barkour/assets",
89+
"brax/experimental/barkour/data",
90+
]
7491

7592
[tool.isort]
7693
force_single_line = true

0 commit comments

Comments
 (0)