OSS async_utils_test.py #37374
Workflow file for this run
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| name: build | |
| on: | |
| # continuous | |
| schedule: | |
| # Run every hour | |
| - cron: "0 * * * *" | |
| push: | |
| branches: | |
| - main | |
| - 'test_*' | |
| pull_request: | |
| branches: | |
| - main | |
| # An explicit permissions block sets every scope not listed here to "none", so each | |
| # scope a step uses with GITHUB_TOKEN must be granted explicitly. | |
| permissions: | |
| contents: read | |
| actions: write # to cancel previous workflows | |
| statuses: write # the "Report success or failure" steps POST to the commit statuses API | |
| concurrency: | |
| group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} | |
| cancel-in-progress: true | |
| jobs: | |
| build-checkpoint: | |
| name: "build-checkpoint (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" | |
| runs-on: ubuntu-latest | |
| defaults: | |
| run: | |
| working-directory: checkpoint | |
| strategy: | |
| matrix: | |
| python-version: ["3.10", "3.11", "3.12"] | |
| jax-version: ["newest"] | |
| include: | |
| - python-version: "3.10" | |
| jax-version: "0.6.0" # keep in sync with minimum version in checkpoint/pyproject.toml | |
| # TODO(b/401258175) Re-enable once JAX nightlies are fixed. | |
| # - python-version: "3.13" | |
| # jax-version: "nightly" | |
| steps: | |
| - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 | |
| - name: Set up Python ${{ matrix.python-version }} | |
| uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 | |
| with: | |
| python-version: ${{ matrix.python-version }} | |
| - name: Free up space | |
| run: | | |
| pip cache purge | |
| sudo docker image prune -a -f | |
| df -h | |
| - name: Install dependencies | |
| # TODO(b/275613424): remove `pip install -e .` and `pip uninstall -y orbax`. | |
| # Currently in place to override remote orbax import due to flax dependency. | |
| run: | | |
| pip install --no-cache-dir -e . | |
| pip install --no-cache-dir -e .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html | |
| pip uninstall -y orbax | |
| if [[ "${{ matrix.jax-version }}" == "newest" ]]; then | |
| pip install -U jax jaxlib | |
| elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then | |
| pip install -U --pre jax jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/ | |
| else | |
| # Pin exactly: the [testing] install above already pulled the newest JAX, so a | |
| # ">=" spec would be a no-op and the minimum supported version would never be tested. | |
| pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" | |
| fi | |
| - name: Test with pytest | |
| # TODO(yaning): Move these to an exclude target within pytest.ini. | |
| run: | | |
| python -m pytest --import-mode=importlib \ | |
| --ignore=orbax/checkpoint/experimental/emergency/broadcast_multislice_test.py \ | |
| --ignore=orbax/checkpoint/experimental/emergency/checkpoint_manager_test.py \ | |
| --ignore=orbax/checkpoint/experimental/emergency/single_slice_checkpoint_manager_test.py \ | |
| --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_data_debugging_test.py \ | |
| --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_manager_test.py \ | |
| --ignore=orbax/checkpoint/experimental/emergency/multihost_test.py \ | |
| --ignore=orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager_test.py \ | |
| --ignore=orbax/checkpoint/_src/testing/multiprocess_test.py \ | |
| --ignore=orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py \ | |
| --ignore=orbax/checkpoint/checkpoint_manager_test.py | |
| # The below step just reports the success or failure of tests as a "commit status". | |
| # This is needed for copybara integration. | |
| - name: Report success or failure as github status | |
| if: always() | |
| shell: bash | |
| run: | | |
| status="${{ job.status }}" | |
| lowercase_status=$(echo "$status" | tr '[:upper:]' '[:lower:]') | |
| # job.status can be "cancelled", but the commit-status API only accepts | |
| # error, failure, pending or success; report anything else as "error". | |
| case "$lowercase_status" in | |
| success|failure) state="$lowercase_status" ;; | |
| *) state="error" ;; | |
| esac | |
| curl -sS --request POST \ | |
| --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \ | |
| --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \ | |
| --header 'content-type: application/json' \ | |
| --data '{ | |
| "state": "'"$state"'", | |
| "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}", | |
| "description": "'"$status"'", | |
| "context": "github-actions/build" | |
| }' | |
| # build-export: | |
| # name: "build-export (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" | |
| # runs-on: ubuntu-latest | |
| # defaults: | |
| # run: | |
| # working-directory: export | |
| # strategy: | |
| # matrix: | |
| # python-version: ["3.10", "3.11", "3.12"] | |
| # jax-version: ["newest"] | |
| # include: | |
| # - python-version: "3.10" | |
| # jax-version: "0.4.34" # keep in sync with minimum version in export/pyproject.toml | |
| # # TODO(b/401258175) Re-enable once JAX nightlies are fixed. | |
| # # - python-version: "3.12" # TODO(jakevdp): update to 3.13 when tf supports it. | |
| # # jax-version: "nightly" | |
| # steps: | |
| # - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 | |
| # - name: Set up Python ${{ matrix.python-version }} | |
| # uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 | |
| # with: | |
| # python-version: ${{ matrix.python-version }} | |
| # - name: Install Protoc | |
| # uses: arduino/setup-protoc@v1 | |
| # with: | |
| # version: '3.x' | |
| # repo-token: ${{ secrets.GITHUB_TOKEN }} | |
| # - name: Extract branch name | |
| # shell: bash | |
| # run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT | |
| # id: extract_branch | |
| # - name: Install dependencies | |
| # run: | | |
| # sudo apt-get update | |
| # sudo apt-get install -y protobuf-compiler | |
| # protoc -I=. --python_out=. $(find orbax/export/ -name "*.proto") | |
| # pip install . | |
| # pip install .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html | |
| # if [[ "${{ matrix.jax-version }}" == "newest" ]]; then | |
| # pip install -U jax jaxlib | |
| # elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then | |
| # pip install -U --pre jax jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/ | |
| # else | |
| # pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" | |
| # fi | |
| # - name: Test with pytest | |
| # run: | | |
| # test_dir=$(mktemp -d) | |
| # cp orbax/export/conftest.py ${test_dir} | |
| # for t in $(find orbax/export -maxdepth 1 -name '*_test.py'); do | |
| # cp ${t} ${test_dir} | |
| # XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest ${test_dir}/$(basename ${t}) | |
| # done | |
| # # The below step just reports the success or failure of tests as a "commit status". | |
| # # This is needed for copybara integration. | |
| # - name: Report success or failure as github status | |
| # if: always() | |
| # shell: bash | |
| # run: | | |
| # status="${{ job.status }}" | |
| # lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]') | |
| # curl -sS --request POST \ | |
| # --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \ | |
| # --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \ | |
| # --header 'content-type: application/json' \ | |
| # --data '{ | |
| # "state": "'$lowercase_status'", | |
| # "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}", | |
| # "description": "'$status'", | |
| # "context": "github-actions/build" | |
| # }' | |
| # build-orbax-model: | |
| # name: "build-orbax-model (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" | |
| # runs-on: ubuntu-latest | |
| # defaults: | |
| # run: | |
| # working-directory: model | |
| # strategy: | |
| # matrix: | |
| # python-version: ["3.10", "3.11", "3.12"] | |
| # jax-version: ["newest"] | |
| # include: | |
| # - python-version: "3.10" | |
| # jax-version: "0.5.0" # keep in sync with minimum version in experimental/model/pyproject.toml | |
| # # - python-version: "3.13" | |
| # # jax-version: "nightly" | |
| # steps: | |
| # - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 | |
| # - name: Set up Python ${{ matrix.python-version }} | |
| # uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 | |
| # with: | |
| # python-version: ${{ matrix.python-version }} | |
| # - name: Install Protoc | |
| # uses: arduino/setup-protoc@v1 | |
| # with: | |
| # version: '3.x' | |
| # repo-token: ${{ secrets.GITHUB_TOKEN }} | |
| # - name: Extract branch name | |
| # shell: bash | |
| # run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT | |
| # id: extract_branch | |
| # - name: Install dependencies | |
| # run: | | |
| # sudo apt-get update | |
| # sudo apt-get install -y protobuf-compiler | |
| # pip install tensorflow | |
| # protoc -I=. --python_out=. $(find orbax/experimental/model/ -name "*.proto") | |
| # pip install -e . | |
| # pip install .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html | |
| # if [[ "${{ matrix.jax-version }}" == "newest" ]]; then | |
| # pip install -U jax jaxlib | |
| # elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then | |
| # pip install -U --pre jax jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/ | |
| # else | |
| # pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" | |
| # fi | |
| # - name: Test with pytest | |
| # run: | | |
| # pytest orbax/experimental/model/core/python/*_test.py | |
| # pytest orbax/experimental/model/tf2obm/*_test.py | |
| # pytest orbax/experimental/model/jax2obm/ \ | |
| # --ignore=orbax/experimental/model/jax2obm/main_lib_test.py \ | |
| # --ignore=orbax/experimental/model/jax2obm/sharding_test.py \ | |
| # --ignore=orbax/experimental/model/jax2obm/jax_to_polymorphic_function_test.py | |
| # - name: Report success or failure as github status | |
| # if: always() | |
| # shell: bash | |
| # run: | | |
| # status="${{ job.status }}" | |
| # lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]') | |
| # curl -sS --request POST \ | |
| # --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \ | |
| # --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \ | |
| # --header 'content-type: application/json' \ | |
| # --data '{ | |
| # "state": "'$lowercase_status'", | |
| # "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}", | |
| # "description": "'$status'", | |
| # "context": "github-actions/build" | |
| # }' | |
| # multiprocess-checkpoint-benchmarks: | |
| # name: "multiprocess-checkpoint-benchmarks (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" | |
| # runs-on: linux-g2-16-l4-1gpu-x4 | |
| # # runs-on: linux-x86-ct5lp-4tpu-x4 | |
| # container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e | |
| # defaults: | |
| # run: | |
| # working-directory: checkpoint | |
| # strategy: | |
| # matrix: | |
| # python-version: ["3.10", "3.11", "3.12"] | |
| # jax-version: ["0.6.0"] | |
| # steps: | |
| # - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 | |
| # - name: Set up Python ${{ matrix.python-version }} | |
| # uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 | |
| # with: | |
| # python-version: ${{ matrix.python-version }} | |
| # - name: Install dependencies | |
| # run: | | |
| # pip install -e . | |
| # pip install -e .[testing,gcs] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html | |
| # pip uninstall -y orbax | |
| # if [[ "${{ matrix.jax-version }}" == "newest" ]]; then | |
| # pip install -U jax[k8s,cuda12] jaxlib | |
| # elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then | |
| # pip install -U --pre jax[k8s,cuda12] jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/ | |
| # else | |
| # pip install "jax[k8s,cuda12]>=${{ matrix.jax-version }}" "jaxlib>=${{ matrix.jax-version }}" | |
| # fi | |
| # pip install gcsfs | |
| # pip install portpicker | |
| # - name: Run benchmarks | |
| # env: | |
| # GCS_BUCKET_PATH: gs://orbax-benchmarks/benchmark-results/${{ github.run_id }} | |
| # run: | | |
| # cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH | |
| # cd ../../../../.. | |
| # # python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py | |
| # # cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH | |
| # # The "Report success or failure as github status" step further below reports the success or | |
| # # failure of tests as a "commit status". That step is needed for copybara integration. | |
| # - name: Run multiprocess tests | |
| # env: | |
| # TEST_TMPDIR: gs://orbax-benchmarks/unit-tests/${{ github.run_id }} | |
| # run: | | |
| # python -c "import jax; jax.distributed.initialize(); print(jax.devices()); import pytest; test_files = [line.strip() for line in open('orbax/checkpoint/_src/testing/multiprocess_tests.txt') if line.strip()]; pytest.main(['-c', '/dev/null'] + test_files)" | |
| # # python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py | |
| # # cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH | |
| # # python -m pytest orbax/checkpoint/checkpoint_manager_test.py | |
| # - name: Report success or failure as github status | |
| # if: always() | |
| # shell: bash | |
| # run: | | |
| # status="${{ job.status }}" | |
| # lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]') | |
| # curl -sS --request POST \ | |
| # --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \ | |
| # --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \ | |
| # --header 'content-type: application/json' \ | |
| # --data '{ | |
| # "state": "'$lowercase_status'", | |
| # "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}", | |
| # "description": "'$status'", | |
| # "context": "github-actions/build" | |
| # }' |