We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent c74477f commit 19bbd9bCopy full SHA for 19bbd9b
MANIFEST.in
@@ -1,3 +1,7 @@
1
+prune docs
2
+prune notebooks
3
+prune datasets
4
+
5
include brax/envs/assets/*.xml
6
recursive-include brax/experimental/barkour *.csv *.stl *.xml
7
recursive-include brax/test_data *.xml *.stl *.obj *.urdf
brax/training/agents/ppo/train.py
@@ -423,13 +423,6 @@ def train(
423
progress_fn=progress_fn,
424
)
425
426
- ckpt_config = checkpoint.network_config(
427
- observation_size=obs_shape,
428
- action_size=env.action_size,
429
- normalize_observations=normalize_observations,
430
- network_factory=network_factory,
431
- )
432
-
433
def minibatch_step(
434
carry,
435
data: types.Transition,
@@ -713,6 +706,12 @@ def training_epoch_with_timing(
713
706
policy_params_fn(current_step, make_policy, params)
714
707
715
708
if save_checkpoint_path is not None:
709
+ ckpt_config = checkpoint.network_config(
710
+ observation_size=obs_shape,
711
+ action_size=env.action_size,
712
+ normalize_observations=normalize_observations,
+ network_factory=network_factory,
+ )
716
checkpoint.save(
717
save_checkpoint_path, current_step, params, ckpt_config
718
brax/training/agents/sac/train.py
@@ -260,13 +260,6 @@ def train(
260
actor_loss, policy_optimizer, pmap_axis_name=_PMAP_AXIS_NAME
261
262
263
264
- observation_size=obs_size,
265
266
267
268
269
270
def sgd_step(
271
carry: Tuple[TrainingState, PRNGKey], transitions: Transition
272
) -> Tuple[Tuple[TrainingState, PRNGKey], Metrics]:
@@ -585,6 +578,12 @@ def training_epoch_with_timing(
585
578
params = _unpmap(
586
579
(training_state.normalizer_params, training_state.policy_params)
587
580
581
582
+ observation_size=obs_size,
583
584
588
checkpoint.save(checkpoint_logdir, current_step, params, ckpt_config)
589
590
# Run evals.
brax/training/checkpoint.py
@@ -71,6 +71,7 @@ def network_config(
71
72
del kwargs['preprocess_observations_fn']
73
if 'activation' in kwargs:
74
+ # TODO: Add other activations.
75
if kwargs['activation'] != defaults['activation']:
76
raise ValueError('checkpointing only supports default activation')
77
del kwargs['activation']
docs/release-notes/next-release.md
@@ -1 +1,3 @@
# Brax Release Notes
+* Fix #595, patch for checkpointing with activations other than relu when no checkpoint path is specified.
pyproject.toml
@@ -71,6 +71,23 @@ Homepage = "http://github.com/google/brax"
[tool.hatch.build.targets.wheel]
packages = ["brax"]
+exclude = [
+ "/datasets",
+ "/docs",
+ "/notebooks",
78
+ "/tests",
79
+ "brax/experimental/barkour/assets",
80
+ "brax/experimental/barkour/data",
81
+]
82
83
+[tool.hatch.build.targets.sdist]
84
85
86
87
88
89
90
91
92
[tool.isort]
93
force_single_line = true
0 commit comments