@@ -237,69 +237,69 @@ jobs:
237237 "context": "github-actions/build"
238238 }'
239239
# Job: installs the checkpoint package plus a matrix-selected jax build, runs
# the pytree checkpoint benchmark, runs the test files listed in
# multiprocess_tests.txt, and finally posts the job result as a commit status.
multiprocess-checkpoint-benchmarks:
  name: "multiprocess-checkpoint-benchmarks (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})"
  runs-on: linux-g2-16-l4-1gpu-x4
  # runs-on: linux-x86-ct5lp-4tpu-x4
  container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e
  defaults:
    run:
      # The install script uses bash-only syntax ([[ ... ]]), so pin bash for
      # every step instead of relying on the runner's default shell.
      shell: bash
      working-directory: checkpoint
  strategy:
    matrix:
      python-version: ["3.10", "3.11", "3.12"]
      jax-version: ["0.6.0"]
  steps:
    - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b  # v5.3.0
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        pip install -e .
        # Requirement specifiers containing [...] are quoted so the shell never
        # treats them as glob patterns.
        pip install -e ".[testing,gcs]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
        pip uninstall -y orbax
        # jax-version is either "newest", "nightly", or a minimum version.
        if [[ "${{ matrix.jax-version }}" == "newest" ]]; then
          pip install -U "jax[k8s,cuda12]" jaxlib
        elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then
          pip install -U --pre "jax[k8s,cuda12]" jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/
        else
          pip install "jax[k8s,cuda12]>=${{ matrix.jax-version }}" "jaxlib>=${{ matrix.jax-version }}"
        fi
        pip install gcsfs
        pip install portpicker
    - name: Run benchmarks
      env:
        # Benchmark output location, keyed by the workflow run id.
        GCS_BUCKET_PATH: gs://orbax-benchmarks/benchmark-results/${{ github.run_id }}
      run: |
        cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
        cd ../../../../..
        # python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py
        # cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
    - name: Run multiprocess tests
      env:
        TEST_TMPDIR: gs://orbax-benchmarks/unit-tests/${{ github.run_id }}
      run: |
        # pytest.main() returns the exit code instead of exiting, so it is
        # wrapped in sys.exit() to make failing tests fail this step.
        python -c "import sys; import jax; jax.distributed.initialize(); print(jax.devices()); import pytest; test_files = [line.strip() for line in open('orbax/checkpoint/_src/testing/multiprocess_tests.txt') if line.strip()]; sys.exit(pytest.main(['-c', '/dev/null'] + test_files))"
        # python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py
        # cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
        # python -m pytest orbax/checkpoint/checkpoint_manager_test.py
    # This step just reports the success or failure of tests as a "commit status".
    # This is needed for copybara integration.
    - name: Report success or failure as github status
      if: always()
      shell: bash
      run: |
        status="${{ job.status }}"
        lowercase_status=$(echo "$status" | tr '[:upper:]' '[:lower:]')
        curl -sS --request POST \
          --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \
          --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \
          --header 'content-type: application/json' \
          --data '{
            "state": "'$lowercase_status'",
            "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}",
            "description": "'$status'",
            "context": "github-actions/build"
          }'
240+ # multiprocess-checkpoint-benchmarks:
241+ # name: "multiprocess-checkpoint-benchmarks (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})"
242+ # runs-on: linux-g2-16-l4-1gpu-x4
243+ # # runs-on: linux-x86-ct5lp-4tpu-x4
244+ # container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e
245+ # defaults:
246+ # run:
247+ # working-directory: checkpoint
248+ # strategy:
249+ # matrix:
250+ # python-version: ["3.10", "3.11", "3.12"]
251+ # jax-version: ["0.6.0"]
252+ # steps:
253+ # - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
254+ # - name: Set up Python ${{ matrix.python-version }}
255+ # uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
256+ # with:
257+ # python-version: ${{ matrix.python-version }}
258+ # - name: Install dependencies
259+ # run: |
260+ # pip install -e .
261+ # pip install -e .[testing,gcs] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
262+ # pip uninstall -y orbax
263+ # if [[ "${{ matrix.jax-version }}" == "newest" ]]; then
264+ # pip install -U jax[k8s,cuda12] jaxlib
265+ # elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then
266+ # pip install -U --pre jax[k8s,cuda12] jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/
267+ # else
268+ # pip install "jax[k8s,cuda12]>=${{ matrix.jax-version }}" "jaxlib>=${{ matrix.jax-version }}"
269+ # fi
270+ # pip install gcsfs
271+ # pip install portpicker
272+ # - name: Run benchmarks
273+ # env:
274+ # GCS_BUCKET_PATH: gs://orbax-benchmarks/benchmark-results/${{ github.run_id }}
275+ # run: |
276+ # cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
277+ # cd ../../../../..
278+ # # python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py
279+ # # cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
280+ # # The final step of this job ("Report success or failure as github status") reports the success or failure of tests as a "commit status".
281+ # # This is needed for copybara integration.
282+ # - name: Run multiprocess tests
283+ # env:
284+ # TEST_TMPDIR: gs://orbax-benchmarks/unit-tests/${{ github.run_id }}
285+ # run: |
286+ # python -c "import jax; jax.distributed.initialize(); print(jax.devices()); import pytest; test_files = [line.strip() for line in open('orbax/checkpoint/_src/testing/multiprocess_tests.txt') if line.strip()]; pytest.main(['-c', '/dev/null'] + test_files)"
287+ # # python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py
288+ # # cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
289+ # # python -m pytest orbax/checkpoint/checkpoint_manager_test.py
290+ # - name: Report success or failure as github status
291+ # if: always()
292+ # shell: bash
293+ # run: |
294+ # status="${{ job.status }}"
295+ # lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]')
296+ # curl -sS --request POST \
297+ # --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \
298+ # --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \
299+ # --header 'content-type: application/json' \
300+ # --data '{
301+ # "state": "'$lowercase_status'",
302+ # "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}",
303+ # "description": "'$status'",
304+ # "context": "github-actions/build"
305+ # }'
0 commit comments