Skip to content

Commit ecb86c8

Browse files
BlaziusMaximus authored and
Orbax Authors committed
Split up large code cells to avoid cell timeouts.
PiperOrigin-RevId: 842299629
1 parent 4fa597b commit ecb86c8

File tree

1 file changed

+118
-64
lines changed

1 file changed

+118
-64
lines changed

docs/guides/checkpoint/v1/orbax_checkpointing_for_pytorch_users.ipynb

Lines changed: 118 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
"id": "eIPQd4Ed6-lr"
6565
},
6666
"source": [
67-
"## 2. Checkpointing: Saving \u0026 Loading Training Progress\n",
67+
"## 2. Checkpointing: Saving & Loading Training Progress\n",
6868
"\n",
6969
"### 2.1 PyTorch Recap: Checkpointing\n",
7070
"\n",
@@ -103,8 +103,15 @@
103103
"import torch.optim as optim\n",
104104
"import os\n",
105105
"import tempfile\n",
106-
"import shutil\n",
107-
"\n",
106+
"import shutil"
107+
]
108+
},
109+
{
110+
"metadata": {
111+
"id": "mA91e7BeNm1V"
112+
},
113+
"cell_type": "code",
114+
"source": [
108115
"# Define a simple model\n",
109116
"class SimpleNet(nn.Module):\n",
110117
" def __init__(self):\n",
@@ -118,8 +125,17 @@
118125
"\n",
119126
"# Create model and optimizer\n",
120127
"model = SimpleNet()\n",
121-
"optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
122-
"\n",
128+
"optimizer = optim.Adam(model.parameters(), lr=0.001)"
129+
],
130+
"outputs": [],
131+
"execution_count": null
132+
},
133+
{
134+
"metadata": {
135+
"id": "nUG5PQHsNqeu"
136+
},
137+
"cell_type": "code",
138+
"source": [
123139
"# Simulate some training\n",
124140
"dummy_input = torch.randn(32, 10)\n",
125141
"dummy_target = torch.randn(32, 1)\n",
@@ -131,8 +147,17 @@
131147
" loss = loss_fn(output, dummy_target)\n",
132148
" loss.backward()\n",
133149
" optimizer.step()\n",
134-
" print(f\"Step {step}, Loss: {loss.item():.4f}\")\n",
135-
"\n",
150+
" print(f\"Step {step}, Loss: {loss.item():.4f}\")"
151+
],
152+
"outputs": [],
153+
"execution_count": null
154+
},
155+
{
156+
"metadata": {
157+
"id": "4fOzySHeNtda"
158+
},
159+
"cell_type": "code",
160+
"source": [
136161
"# Save checkpoint\n",
137162
"tmpdir = tempfile.mkdtemp()\n",
138163
"checkpoint_path = os.path.join(tmpdir, 'pytorch_checkpoint.pth')\n",
@@ -154,7 +179,9 @@
154179
"\n",
155180
"print(f\"Loaded checkpoint from step {step} with loss {loss:.4f}\")\n",
156181
"shutil.rmtree(tmpdir)"
157-
]
182+
],
183+
"outputs": [],
184+
"execution_count": null
158185
},
159186
{
160187
"cell_type": "markdown",
@@ -180,8 +207,61 @@
180207
"from flax.training import train_state\n",
181208
"import tempfile\n",
182209
"import shutil\n",
183-
"import os\n",
184-
"\n",
210+
"import os"
211+
],
212+
"metadata": {
213+
"colab": {
214+
"base_uri": "https://localhost:8080/"
215+
},
216+
"id": "AMf1ZUqf7-54",
217+
"outputId": "8096165c-ea8a-49f1-f661-c9063a7e5bd8"
218+
},
219+
"execution_count": null,
220+
"outputs": [
221+
{
222+
"output_type": "stream",
223+
"name": "stderr",
224+
"text": [
225+
"WARNING:absl:[process=0][thread=MainThread][operation_id=1] _SignalingThread.join() waiting for signals ([]) blocking the main thread will slow down blocking save times. This is likely due to main thread calling result() on a CommitFuture.\n"
226+
]
227+
},
228+
{
229+
"output_type": "stream",
230+
"name": "stdout",
231+
"text": [
232+
"Step 1, Loss: 1.7981\n",
233+
"Step 2, Loss: 1.7822\n",
234+
"Step 3, Loss: 1.7664\n",
235+
"Step 4, Loss: 1.7507\n",
236+
"Step 5, Loss: 1.7352\n",
237+
"\n",
238+
"Saved checkpoint to /tmp/tmpphbau1yv\n"
239+
]
240+
},
241+
{
242+
"output_type": "stream",
243+
"name": "stderr",
244+
"text": [
245+
"/usr/local/lib/python3.12/dist-packages/orbax/checkpoint/_src/serialization/type_handlers.py:1269: UserWarning: Sharding info not provided when restoring. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.\n",
246+
" warnings.warn(\n"
247+
]
248+
},
249+
{
250+
"output_type": "stream",
251+
"name": "stdout",
252+
"text": [
253+
"Loaded checkpoint at step 5\n",
254+
"Parameters match: True\n"
255+
]
256+
}
257+
]
258+
},
259+
{
260+
"metadata": {
261+
"id": "BzBcIeKNNSgy"
262+
},
263+
"cell_type": "code",
264+
"source": [
185265
"# Define a simple model using Flax\n",
186266
"class SimpleNet(fnn.Module):\n",
187267
" @fnn.compact\n",
@@ -202,8 +282,17 @@
202282
" apply_fn=model.apply,\n",
203283
" params=params,\n",
204284
" tx=tx\n",
205-
")\n",
206-
"\n",
285+
")"
286+
],
287+
"outputs": [],
288+
"execution_count": null
289+
},
290+
{
291+
"metadata": {
292+
"id": "o1EPp6wfNZrJ"
293+
},
294+
"cell_type": "code",
295+
"source": [
207296
"# Define the Training Step\n",
208297
"@jax.jit\n",
209298
"def train_step(state, batch_input, batch_target):\n",
@@ -219,8 +308,17 @@
219308
"dummy_target = jax.random.normal(key, (32, 1))\n",
220309
"for _ in range(5):\n",
221310
" state, loss = train_step(state, dummy_input, dummy_target)\n",
222-
" print(f\"Step {state.step}, Loss: {loss:.4f}\")\n",
223-
"\n",
311+
" print(f\"Step {state.step}, Loss: {loss:.4f}\")"
312+
],
313+
"outputs": [],
314+
"execution_count": null
315+
},
316+
{
317+
"metadata": {
318+
"id": "b6TkRRaHNd0r"
319+
},
320+
"cell_type": "code",
321+
"source": [
224322
"# Initialize a PyTreeCheckpointer, designed for saving single PyTrees.\n",
225323
"checkpointer = ocp.PyTreeCheckpointer()\n",
226324
"checkpoint_dir = tempfile.mkdtemp()\n",
@@ -245,52 +343,8 @@
245343
"# Clean up the temporary directory.\n",
246344
"shutil.rmtree(checkpoint_dir)"
247345
],
248-
"metadata": {
249-
"colab": {
250-
"base_uri": "https://localhost:8080/"
251-
},
252-
"id": "AMf1ZUqf7-54",
253-
"outputId": "8096165c-ea8a-49f1-f661-c9063a7e5bd8"
254-
},
255-
"execution_count": null,
256-
"outputs": [
257-
{
258-
"output_type": "stream",
259-
"name": "stderr",
260-
"text": [
261-
"WARNING:absl:[process=0][thread=MainThread][operation_id=1] _SignalingThread.join() waiting for signals ([]) blocking the main thread will slow down blocking save times. This is likely due to main thread calling result() on a CommitFuture.\n"
262-
]
263-
},
264-
{
265-
"output_type": "stream",
266-
"name": "stdout",
267-
"text": [
268-
"Step 1, Loss: 1.7981\n",
269-
"Step 2, Loss: 1.7822\n",
270-
"Step 3, Loss: 1.7664\n",
271-
"Step 4, Loss: 1.7507\n",
272-
"Step 5, Loss: 1.7352\n",
273-
"\n",
274-
"Saved checkpoint to /tmp/tmpphbau1yv\n"
275-
]
276-
},
277-
{
278-
"output_type": "stream",
279-
"name": "stderr",
280-
"text": [
281-
"/usr/local/lib/python3.12/dist-packages/orbax/checkpoint/_src/serialization/type_handlers.py:1269: UserWarning: Sharding info not provided when restoring. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.\n",
282-
" warnings.warn(\n"
283-
]
284-
},
285-
{
286-
"output_type": "stream",
287-
"name": "stdout",
288-
"text": [
289-
"Loaded checkpoint at step 5\n",
290-
"Parameters match: True\n"
291-
]
292-
}
293-
]
346+
"outputs": [],
347+
"execution_count": null
294348
},
295349
{
296350
"cell_type": "markdown",
@@ -354,7 +408,7 @@
354408
"\n",
355409
" if step % 6 == 0:\n",
356410
" checkpoint_manager.save(step, args=ocp.args.StandardSave(state))\n",
357-
" print(f\" -\u003e Saved checkpoint for step {step}\")\n",
411+
" print(f\" -> Saved checkpoint for step {step}\")\n",
358412
"\n",
359413
"\n",
360414
"print(f\"\\nAvailable checkpoints: {checkpoint_manager.all_steps()}\")\n",
@@ -385,21 +439,21 @@
385439
"Step 4: Loss: 1.7507\n",
386440
"Step 5: Loss: 1.7352\n",
387441
"Step 6: Loss: 1.7198\n",
388-
" -\u003e Saved checkpoint for step 6\n",
442+
" -> Saved checkpoint for step 6\n",
389443
"Step 7: Loss: 1.7046\n",
390444
"Step 8: Loss: 1.6895\n",
391445
"Step 9: Loss: 1.6746\n",
392446
"Step 10: Loss: 1.6598\n",
393447
"Step 11: Loss: 1.6452\n",
394448
"Step 12: Loss: 1.6308\n",
395-
" -\u003e Saved checkpoint for step 12\n",
449+
" -> Saved checkpoint for step 12\n",
396450
"Step 13: Loss: 1.6165\n",
397451
"Step 14: Loss: 1.6024\n",
398452
"Step 15: Loss: 1.5885\n",
399453
"Step 16: Loss: 1.5749\n",
400454
"Step 17: Loss: 1.5614\n",
401455
"Step 18: Loss: 1.5480\n",
402-
" -\u003e Saved checkpoint for step 18\n",
456+
" -> Saved checkpoint for step 18\n",
403457
"Step 19: Loss: 1.5348\n",
404458
"Step 20: Loss: 1.5218\n",
405459
"\n",

0 commit comments

Comments
 (0)