|
64 | 64 | "id": "eIPQd4Ed6-lr" |
65 | 65 | }, |
66 | 66 | "source": [ |
67 | | - "## 2. Checkpointing: Saving \u0026 Loading Training Progress\n", |
| 67 | + "## 2. Checkpointing: Saving & Loading Training Progress\n", |
68 | 68 | "\n", |
69 | 69 | "### 2.1 PyTorch Recap: Checkpointing\n", |
70 | 70 | "\n", |
|
103 | 103 | "import torch.optim as optim\n", |
104 | 104 | "import os\n", |
105 | 105 | "import tempfile\n", |
106 | | - "import shutil\n", |
107 | | - "\n", |
| 106 | + "import shutil" |
| 107 | + ] |
| 108 | + }, |
| 109 | + { |
| 110 | + "metadata": { |
| 111 | + "id": "mA91e7BeNm1V" |
| 112 | + }, |
| 113 | + "cell_type": "code", |
| 114 | + "source": [ |
108 | 115 | "# Define a simple model\n", |
109 | 116 | "class SimpleNet(nn.Module):\n", |
110 | 117 | " def __init__(self):\n", |
|
118 | 125 | "\n", |
119 | 126 | "# Create model and optimizer\n", |
120 | 127 | "model = SimpleNet()\n", |
121 | | - "optimizer = optim.Adam(model.parameters(), lr=0.001)\n", |
122 | | - "\n", |
| 128 | + "optimizer = optim.Adam(model.parameters(), lr=0.001)" |
| 129 | + ], |
| 130 | + "outputs": [], |
| 131 | + "execution_count": null |
| 132 | + }, |
| 133 | + { |
| 134 | + "metadata": { |
| 135 | + "id": "nUG5PQHsNqeu" |
| 136 | + }, |
| 137 | + "cell_type": "code", |
| 138 | + "source": [ |
123 | 139 | "# Simulate some training\n", |
124 | 140 | "dummy_input = torch.randn(32, 10)\n", |
125 | 141 | "dummy_target = torch.randn(32, 1)\n", |
|
131 | 147 | " loss = loss_fn(output, dummy_target)\n", |
132 | 148 | " loss.backward()\n", |
133 | 149 | " optimizer.step()\n", |
134 | | - " print(f\"Step {step}, Loss: {loss.item():.4f}\")\n", |
135 | | - "\n", |
| 150 | + " print(f\"Step {step}, Loss: {loss.item():.4f}\")" |
| 151 | + ], |
| 152 | + "outputs": [], |
| 153 | + "execution_count": null |
| 154 | + }, |
| 155 | + { |
| 156 | + "metadata": { |
| 157 | + "id": "4fOzySHeNtda" |
| 158 | + }, |
| 159 | + "cell_type": "code", |
| 160 | + "source": [ |
136 | 161 | "# Save checkpoint\n", |
137 | 162 | "tmpdir = tempfile.mkdtemp()\n", |
138 | 163 | "checkpoint_path = os.path.join(tmpdir, 'pytorch_checkpoint.pth')\n", |
|
154 | 179 | "\n", |
155 | 180 | "print(f\"Loaded checkpoint from step {step} with loss {loss:.4f}\")\n", |
156 | 181 | "shutil.rmtree(tmpdir)" |
157 | | - ] |
| 182 | + ], |
| 183 | + "outputs": [], |
| 184 | + "execution_count": null |
158 | 185 | }, |
159 | 186 | { |
160 | 187 | "cell_type": "markdown", |
|
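Note: the actual `torch.save` / `torch.load` calls fall in the elided middle of the hunk above. For reference, a minimal sketch of the usual PyTorch pattern, assuming (as the later `Loaded checkpoint from step {step} with loss {loss:.4f}` print suggests) that the checkpoint bundles the model and optimizer state dicts together with the current step and loss; this is an assumed reconstruction, not the notebook's exact code.

```python
# Assumed PyTorch checkpoint round-trip (sketch only); reuses model, optimizer,
# step, loss and checkpoint_path from the visible cells.
checkpoint = {
    "step": step,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": loss.item(),
}
torch.save(checkpoint, checkpoint_path)

# Restore: rebuild the model/optimizer first, then load the saved states into them.
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
step = checkpoint["step"]
loss = checkpoint["loss"]
```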
180 | 207 | "from flax.training import train_state\n", |
181 | 208 | "import tempfile\n", |
182 | 209 | "import shutil\n", |
183 | | - "import os\n", |
184 | | - "\n", |
| 210 | + "import os" |
| 211 | + ], |
| 212 | + "metadata": { |
| 213 | + "colab": { |
| 214 | + "base_uri": "https://localhost:8080/" |
| 215 | + }, |
| 216 | + "id": "AMf1ZUqf7-54", |
| 217 | + "outputId": "8096165c-ea8a-49f1-f661-c9063a7e5bd8" |
| 218 | + }, |
| 219 | + "execution_count": null, |
| 220 | + "outputs": [ |
| 221 | + { |
| 222 | + "output_type": "stream", |
| 223 | + "name": "stderr", |
| 224 | + "text": [ |
| 225 | + "WARNING:absl:[process=0][thread=MainThread][operation_id=1] _SignalingThread.join() waiting for signals ([]) blocking the main thread will slow down blocking save times. This is likely due to main thread calling result() on a CommitFuture.\n" |
| 226 | + ] |
| 227 | + }, |
| 228 | + { |
| 229 | + "output_type": "stream", |
| 230 | + "name": "stdout", |
| 231 | + "text": [ |
| 232 | + "Step 1, Loss: 1.7981\n", |
| 233 | + "Step 2, Loss: 1.7822\n", |
| 234 | + "Step 3, Loss: 1.7664\n", |
| 235 | + "Step 4, Loss: 1.7507\n", |
| 236 | + "Step 5, Loss: 1.7352\n", |
| 237 | + "\n", |
| 238 | + "Saved checkpoint to /tmp/tmpphbau1yv\n" |
| 239 | + ] |
| 240 | + }, |
| 241 | + { |
| 242 | + "output_type": "stream", |
| 243 | + "name": "stderr", |
| 244 | + "text": [ |
| 245 | + "/usr/local/lib/python3.12/dist-packages/orbax/checkpoint/_src/serialization/type_handlers.py:1269: UserWarning: Sharding info not provided when restoring. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.\n", |
| 246 | + " warnings.warn(\n" |
| 247 | + ] |
| 248 | + }, |
| 249 | + { |
| 250 | + "output_type": "stream", |
| 251 | + "name": "stdout", |
| 252 | + "text": [ |
| 253 | + "Loaded checkpoint at step 5\n", |
| 254 | + "Parameters match: True\n" |
| 255 | + ] |
| 256 | + } |
| 257 | + ] |
| 258 | + }, |
| 259 | + { |
| 260 | + "metadata": { |
| 261 | + "id": "BzBcIeKNNSgy" |
| 262 | + }, |
| 263 | + "cell_type": "code", |
| 264 | + "source": [ |
185 | 265 | "# Define a simple model using Flax\n", |
186 | 266 | "class SimpleNet(fnn.Module):\n", |
187 | 267 | " @fnn.compact\n", |
|
202 | 282 | " apply_fn=model.apply,\n", |
203 | 283 | " params=params,\n", |
204 | 284 | " tx=tx\n", |
205 | | - ")\n", |
206 | | - "\n", |
| 285 | + ")" |
| 286 | + ], |
| 287 | + "outputs": [], |
| 288 | + "execution_count": null |
| 289 | + }, |
| 290 | + { |
| 291 | + "metadata": { |
| 292 | + "id": "o1EPp6wfNZrJ" |
| 293 | + }, |
| 294 | + "cell_type": "code", |
| 295 | + "source": [ |
207 | 296 | "# Define the Training Step\n", |
208 | 297 | "@jax.jit\n", |
209 | 298 | "def train_step(state, batch_input, batch_target):\n", |
|
219 | 308 | "dummy_target = jax.random.normal(key, (32, 1))\n", |
220 | 309 | "for _ in range(5):\n", |
221 | 310 | " state, loss = train_step(state, dummy_input, dummy_target)\n", |
222 | | - " print(f\"Step {state.step}, Loss: {loss:.4f}\")\n", |
223 | | - "\n", |
| 311 | + " print(f\"Step {state.step}, Loss: {loss:.4f}\")" |
| 312 | + ], |
| 313 | + "outputs": [], |
| 314 | + "execution_count": null |
| 315 | + }, |
| 316 | + { |
| 317 | + "metadata": { |
| 318 | + "id": "b6TkRRaHNd0r" |
| 319 | + }, |
| 320 | + "cell_type": "code", |
| 321 | + "source": [ |
224 | 322 | "# Initialize a PyTreeCheckpointer, designed for saving single PyTrees.\n", |
225 | 323 | "checkpointer = ocp.PyTreeCheckpointer()\n", |
226 | 324 | "checkpoint_dir = tempfile.mkdtemp()\n", |
|
245 | 343 | "# Clean up the temporary directory.\n", |
246 | 344 | "shutil.rmtree(checkpoint_dir)" |
247 | 345 | ], |
248 | | - "metadata": { |
249 | | - "colab": { |
250 | | - "base_uri": "https://localhost:8080/" |
251 | | - }, |
252 | | - "id": "AMf1ZUqf7-54", |
253 | | - "outputId": "8096165c-ea8a-49f1-f661-c9063a7e5bd8" |
254 | | - }, |
255 | | - "execution_count": null, |
256 | | - "outputs": [ |
257 | | - { |
258 | | - "output_type": "stream", |
259 | | - "name": "stderr", |
260 | | - "text": [ |
261 | | - "WARNING:absl:[process=0][thread=MainThread][operation_id=1] _SignalingThread.join() waiting for signals ([]) blocking the main thread will slow down blocking save times. This is likely due to main thread calling result() on a CommitFuture.\n" |
262 | | - ] |
263 | | - }, |
264 | | - { |
265 | | - "output_type": "stream", |
266 | | - "name": "stdout", |
267 | | - "text": [ |
268 | | - "Step 1, Loss: 1.7981\n", |
269 | | - "Step 2, Loss: 1.7822\n", |
270 | | - "Step 3, Loss: 1.7664\n", |
271 | | - "Step 4, Loss: 1.7507\n", |
272 | | - "Step 5, Loss: 1.7352\n", |
273 | | - "\n", |
274 | | - "Saved checkpoint to /tmp/tmpphbau1yv\n" |
275 | | - ] |
276 | | - }, |
277 | | - { |
278 | | - "output_type": "stream", |
279 | | - "name": "stderr", |
280 | | - "text": [ |
281 | | - "/usr/local/lib/python3.12/dist-packages/orbax/checkpoint/_src/serialization/type_handlers.py:1269: UserWarning: Sharding info not provided when restoring. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.\n", |
282 | | - " warnings.warn(\n" |
283 | | - ] |
284 | | - }, |
285 | | - { |
286 | | - "output_type": "stream", |
287 | | - "name": "stdout", |
288 | | - "text": [ |
289 | | - "Loaded checkpoint at step 5\n", |
290 | | - "Parameters match: True\n" |
291 | | - ] |
292 | | - } |
293 | | - ] |
| 346 | + "outputs": [], |
| 347 | + "execution_count": null |
294 | 348 | }, |
295 | 349 | { |
296 | 350 | "cell_type": "markdown", |
|
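Note: the `PyTreeCheckpointer` save/restore calls also sit in the elided part of the hunk above; the captured output (the sharding `UserWarning`, `Loaded checkpoint at step 5`, `Parameters match: True`) is consistent with a round-trip along the following lines. A hedged sketch that reuses `checkpointer`, `checkpoint_dir`, and `state` from the visible cells; the `"state"` subdirectory name is an assumption.

```python
# Assumed PyTreeCheckpointer round-trip (sketch only, not the notebook's exact code).
import os
import jax

# Save the whole TrainState PyTree; Orbax refuses to overwrite an existing
# directory, so write into a fresh (hypothetical) subdirectory of the tmpdir.
save_path = os.path.join(checkpoint_dir, "state")
checkpointer.save(save_path, state)
print(f"Saved checkpoint to {checkpoint_dir}")

# Restore into a template with the same structure as what was saved.
restored = checkpointer.restore(save_path, item=state)
print(f"Loaded checkpoint at step {restored.step}")

# Compare restored and in-memory parameters leaf by leaf.
params_match = jax.tree_util.tree_all(
    jax.tree_util.tree_map(lambda a, b: bool((a == b).all()), state.params, restored.params)
)
print(f"Parameters match: {params_match}")
```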
354 | 408 | "\n", |
355 | 409 | " if step % 6 == 0:\n", |
356 | 410 | " checkpoint_manager.save(step, args=ocp.args.StandardSave(state))\n", |
357 | | - " print(f\" -\u003e Saved checkpoint for step {step}\")\n", |
| 411 | + " print(f\" -> Saved checkpoint for step {step}\")\n", |
358 | 412 | "\n", |
359 | 413 | "\n", |
360 | 414 | "print(f\"\\nAvailable checkpoints: {checkpoint_manager.all_steps()}\")\n", |
|
385 | 439 | "Step 4: Loss: 1.7507\n", |
386 | 440 | "Step 5: Loss: 1.7352\n", |
387 | 441 | "Step 6: Loss: 1.7198\n", |
388 | | - " -\u003e Saved checkpoint for step 6\n", |
| 442 | + " -> Saved checkpoint for step 6\n", |
389 | 443 | "Step 7: Loss: 1.7046\n", |
390 | 444 | "Step 8: Loss: 1.6895\n", |
391 | 445 | "Step 9: Loss: 1.6746\n", |
392 | 446 | "Step 10: Loss: 1.6598\n", |
393 | 447 | "Step 11: Loss: 1.6452\n", |
394 | 448 | "Step 12: Loss: 1.6308\n", |
395 | | - " -\u003e Saved checkpoint for step 12\n", |
| 449 | + " -> Saved checkpoint for step 12\n", |
396 | 450 | "Step 13: Loss: 1.6165\n", |
397 | 451 | "Step 14: Loss: 1.6024\n", |
398 | 452 | "Step 15: Loss: 1.5885\n", |
399 | 453 | "Step 16: Loss: 1.5749\n", |
400 | 454 | "Step 17: Loss: 1.5614\n", |
401 | 455 | "Step 18: Loss: 1.5480\n", |
402 | | - " -\u003e Saved checkpoint for step 18\n", |
| 456 | + " -> Saved checkpoint for step 18\n", |
403 | 457 | "Step 19: Loss: 1.5348\n", |
404 | 458 | "Step 20: Loss: 1.5218\n", |
405 | 459 | "\n", |
|
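Note: the construction of `checkpoint_manager` is not in the visible diff. Below is a hedged sketch of a typical `CheckpointManager` setup and of restoring the latest retained step afterwards, reusing `state` from the earlier cells; the directory location and `max_to_keep` value are assumptions.

```python
# Assumed CheckpointManager setup and restore (sketch of typical usage,
# not the notebook's elided configuration).
import tempfile
import orbax.checkpoint as ocp

manager_dir = tempfile.mkdtemp()  # hypothetical location
options = ocp.CheckpointManagerOptions(max_to_keep=3)  # assumed retention policy
checkpoint_manager = ocp.CheckpointManager(manager_dir, options=options)

# Inside the training loop (as in the hunk above):
#     checkpoint_manager.save(step, args=ocp.args.StandardSave(state))

# After training, block until any asynchronous saves have committed, then
# restore the newest retained step using the live state as a structure template.
checkpoint_manager.wait_until_finished()
print(f"Available checkpoints: {checkpoint_manager.all_steps()}")
latest = checkpoint_manager.latest_step()
restored_state = checkpoint_manager.restore(latest, args=ocp.args.StandardRestore(state))
print(f"Restored step: {restored_state.step}")
```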