From db22d4f594e6ed65cadfc421ac9c087541fb8206 Mon Sep 17 00:00:00 2001 From: Quoc Truong Date: Wed, 13 Aug 2025 11:00:56 -0700 Subject: [PATCH 01/22] Test multi-host runner --- .github/workflows/test.yml | 73 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 .github/workflows/test.yml diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 000000000..3dc5ca41f --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,73 @@ +name: build + +on: + pull_request: + branches: + - main + +permissions: + contents: read + actions: write # to cancel previous workflows + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + +jobs: + build-checkpoint: + name: "build-checkpoint (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" + runs-on: linux-g2-16-l4-1gpu-x4 + container: python:3.11 + defaults: + run: + working-directory: checkpoint + strategy: + matrix: + python-version: ["3.10", "3.11", "3.12"] + jax-version: ["newest"] + include: + - python-version: "3.10" + jax-version: "0.5.0" # keep in sync with minimum version in checkpoint/pyproject.toml + # TODO(b/401258175) Re-enable once JAX nightlies are fixed. + # - python-version: "3.13" + # jax-version: "nightly" + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: ${{ matrix.python-version }} + - name: Extract branch name + shell: bash + run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT + id: extract_branch + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y protobuf-compiler + + pip install tensorflow + + protoc -I=. --python_out=. $(find orbax/experimental/model/ -name "*.proto") + + pip install -e . + + pip install .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + + if [[ "${{ matrix.jax-version }}" == "newest" ]]; then + pip install -U jax jaxlib + elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then + pip install -U --pre jax jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/ + else + pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" + fi + - name: Test with pytest + run: | + pytest orbax/experimental/model/core/python/*_test.py + + pytest orbax/experimental/model/tf2obm/*_test.py + + pytest orbax/experimental/model/jax2obm/ \ + --ignore=orbax/experimental/model/jax2obm/main_lib_test.py \ + --ignore=orbax/experimental/model/jax2obm/sharding_test.py \ + --ignore=orbax/experimental/model/jax2obm/jax_to_polymorphic_function_test.py From fdab1b417b175e6f241d208aacbd2a132343268c Mon Sep 17 00:00:00 2001 From: Quoc Truong Date: Wed, 13 Aug 2025 23:24:17 -0700 Subject: [PATCH 02/22] Test workflow --- .github/workflows/test.yml | 24 ++++-------------------- 1 file changed, 4 insertions(+), 20 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3dc5ca41f..74da8422b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,28 +15,18 @@ concurrency: jobs: build-checkpoint: - name: "build-checkpoint (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" + name: "build-checkpoint multi-host" runs-on: linux-g2-16-l4-1gpu-x4 container: python:3.11 defaults: run: working-directory: checkpoint - strategy: - matrix: - python-version: ["3.10", "3.11", "3.12"] - jax-version: ["newest"] - include: - - python-version: "3.10" - jax-version: "0.5.0" # keep in sync with minimum version in checkpoint/pyproject.toml - # TODO(b/401258175) Re-enable once JAX nightlies are fixed. - # - python-version: "3.13" - # jax-version: "nightly" steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python ${{ matrix.python-version }} + - name: Set up Python 3.11 uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: - python-version: ${{ matrix.python-version }} + python-version: 3.11 - name: Extract branch name shell: bash run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT @@ -54,13 +44,7 @@ jobs: pip install .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - if [[ "${{ matrix.jax-version }}" == "newest" ]]; then - pip install -U jax jaxlib - elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then - pip install -U --pre jax jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/ - else - pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" - fi + pip install -U jax jaxlib - name: Test with pytest run: | pytest orbax/experimental/model/core/python/*_test.py From 07dfdc31e02012f3545bec3637ece40d3ee66d15 Mon Sep 17 00:00:00 2001 From: Quoc Truong Date: Wed, 13 Aug 2025 23:26:23 -0700 Subject: [PATCH 03/22] Remove setup python for now --- .github/workflows/test.yml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 74da8422b..663695242 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -23,10 +23,6 @@ jobs: working-directory: checkpoint steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python 3.11 - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - with: - python-version: 3.11 - name: Extract branch name shell: bash run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT From 2584ecf63a50cb4cc3d94d0383689c70c2414e76 Mon Sep 17 00:00:00 2001 From: Quoc Truong Date: Thu, 14 Aug 2025 09:23:22 -0700 Subject: [PATCH 04/22] Try new test --- .github/workflows/build.yml | 10 ++++--- .github/workflows/test.yml | 53 ------------------------------------- 2 files changed, 6 insertions(+), 57 deletions(-) delete mode 100644 .github/workflows/test.yml diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 60a82122f..1d1e5b886 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -24,13 +24,14 @@ concurrency: jobs: build-checkpoint: name: "build-checkpoint (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" - runs-on: ubuntu-latest + runs-on: linux-x86-n2-32 + container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e defaults: run: working-directory: checkpoint strategy: matrix: - python-version: ["3.10", "3.11", "3.12"] + python-version: ["3.10"] jax-version: ["newest"] include: - python-version: "3.10" @@ -83,13 +84,14 @@ jobs: build-export: name: "build-export (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" - runs-on: ubuntu-latest + runs-on: linux-g2-16-l4-1gpu-x4 + container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e defaults: run: working-directory: export strategy: matrix: - python-version: ["3.10", "3.11", "3.12"] + python-version: ["3.10"] jax-version: ["newest"] include: - python-version: "3.10" diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml deleted file mode 100644 index 663695242..000000000 --- a/.github/workflows/test.yml +++ /dev/null @@ -1,53 +0,0 @@ -name: build - -on: - pull_request: - branches: - - main - -permissions: - contents: read - actions: write # to cancel previous workflows - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} - cancel-in-progress: true - -jobs: - build-checkpoint: - name: "build-checkpoint multi-host" - runs-on: linux-g2-16-l4-1gpu-x4 - container: python:3.11 - defaults: - run: - working-directory: checkpoint - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Extract branch name - shell: bash - run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT - id: extract_branch - - name: Install dependencies - run: | - sudo apt-get update - sudo apt-get install -y protobuf-compiler - - pip install tensorflow - - protoc -I=. --python_out=. $(find orbax/experimental/model/ -name "*.proto") - - pip install -e . - - pip install .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - - pip install -U jax jaxlib - - name: Test with pytest - run: | - pytest orbax/experimental/model/core/python/*_test.py - - pytest orbax/experimental/model/tf2obm/*_test.py - - pytest orbax/experimental/model/jax2obm/ \ - --ignore=orbax/experimental/model/jax2obm/main_lib_test.py \ - --ignore=orbax/experimental/model/jax2obm/sharding_test.py \ - --ignore=orbax/experimental/model/jax2obm/jax_to_polymorphic_function_test.py From 4d221ccc8a31bde0a1b7e3f18f910190cd3517e5 Mon Sep 17 00:00:00 2001 From: Quoc Truong Date: Thu, 14 Aug 2025 17:06:13 -0700 Subject: [PATCH 05/22] Try multi-host --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 1d1e5b886..9352a181d 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -24,7 +24,7 @@ concurrency: jobs: build-checkpoint: name: "build-checkpoint (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" - runs-on: linux-x86-n2-32 + runs-on: linux-g2-16-l4-1gpu-x4 container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e defaults: run: From ab0e00ea60b8afbbe2b69dcf89446c0e6171c323 Mon Sep 17 00:00:00 2001 From: Quoc Truong Date: Thu, 14 Aug 2025 22:35:53 -0700 Subject: [PATCH 06/22] Try with n2-32 --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 9352a181d..77705b939 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -84,7 +84,7 @@ jobs: build-export: name: "build-export (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" - runs-on: linux-g2-16-l4-1gpu-x4 + runs-on: linux-x86-n2-32 container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e defaults: run: From 42bd99e69291bcda45218585d774aae2e1927663 Mon Sep 17 00:00:00 2001 From: Quoc Truong Date: Thu, 14 Aug 2025 23:09:12 -0700 Subject: [PATCH 07/22] Test 1gpu --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 77705b939..9352a181d 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -84,7 +84,7 @@ jobs: build-export: name: "build-export (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" - runs-on: linux-x86-n2-32 + runs-on: linux-g2-16-l4-1gpu-x4 container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e defaults: run: From be94ac546d560413503cd92e0f60c294e3ba00af Mon Sep 17 00:00:00 2001 From: Quoc Truong Date: Thu, 14 Aug 2025 23:20:59 -0700 Subject: [PATCH 08/22] build model --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 9352a181d..93d086a21 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -159,7 +159,7 @@ jobs: build-orbax-model: name: "build-orbax-model (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" - runs-on: ubuntu-latest + runs-on: linux-g2-16-l4-1gpu-x4 defaults: run: working-directory: model From 1a02869c5bc632311f602a4337ab57afb353a014 Mon Sep 17 00:00:00 2001 From: Quoc Truong Date: Thu, 14 Aug 2025 23:25:25 -0700 Subject: [PATCH 09/22] Add container --- .github/workflows/build.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 93d086a21..b404348c5 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -119,7 +119,7 @@ jobs: sudo apt-get update sudo apt-get install -y protobuf-compiler - protoc -I=. --python_out=. $(find orbax/export/ -name "*.proto") + protoc --experimental_allow_proto3_optional -I=. --python_out=. $(find orbax/export/ -name "*.proto") pip install . pip install .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html @@ -160,6 +160,7 @@ jobs: build-orbax-model: name: "build-orbax-model (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" runs-on: linux-g2-16-l4-1gpu-x4 + container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e defaults: run: working-directory: model @@ -194,7 +195,7 @@ jobs: pip install tensorflow - protoc -I=. --python_out=. $(find orbax/experimental/model/ -name "*.proto") + protoc --experimental_allow_proto3_optional -I=. --python_out=. $(find orbax/experimental/model/ -name "*.proto") pip install -e . From 6cea85659d0d1290dc04cbc648370330992055d5 Mon Sep 17 00:00:00 2001 From: Quoc Truong Date: Thu, 14 Aug 2025 23:31:47 -0700 Subject: [PATCH 10/22] Remove test --- .github/workflows/build.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index b404348c5..1827ec8d0 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -119,7 +119,7 @@ jobs: sudo apt-get update sudo apt-get install -y protobuf-compiler - protoc --experimental_allow_proto3_optional -I=. --python_out=. $(find orbax/export/ -name "*.proto") + protoc -I=. --python_out=. $(find orbax/export/ -name "*.proto") pip install . pip install .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html @@ -195,7 +195,7 @@ jobs: pip install tensorflow - protoc --experimental_allow_proto3_optional -I=. --python_out=. $(find orbax/experimental/model/ -name "*.proto") + protoc -I=. --python_out=. $(find orbax/experimental/model/ -name "*.proto") pip install -e . From 68d6ee014e2f8ea71a96671e48480150379ba0c2 Mon Sep 17 00:00:00 2001 From: Quoc Truong Date: Fri, 15 Aug 2025 10:17:09 -0700 Subject: [PATCH 11/22] Add some sleep --- .github/workflows/build.yml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 1827ec8d0..62ab0e503 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -24,7 +24,7 @@ concurrency: jobs: build-checkpoint: name: "build-checkpoint (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" - runs-on: linux-g2-16-l4-1gpu-x4 + runs-on: linux-x86-n2-32 container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e defaults: run: @@ -84,7 +84,7 @@ jobs: build-export: name: "build-export (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" - runs-on: linux-g2-16-l4-1gpu-x4 + runs-on: linux-x86-n2-32 container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e defaults: run: @@ -110,6 +110,8 @@ jobs: with: version: '3.x' repo-token: ${{ secrets.GITHUB_TOKEN }} + - name: sleep + run: sleep 10000 - name: Extract branch name shell: bash run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT @@ -159,7 +161,7 @@ jobs: build-orbax-model: name: "build-orbax-model (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" - runs-on: linux-g2-16-l4-1gpu-x4 + runs-on: linux-x86-n2-32 container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e defaults: run: From 3e9fb710fbaae14e0e34f30aefdd3cd471d3f7d7 Mon Sep 17 00:00:00 2001 From: Quoc Truong Date: Fri, 15 Aug 2025 17:56:34 -0700 Subject: [PATCH 12/22] Test --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 62ab0e503..eeb9387d7 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -84,7 +84,7 @@ jobs: build-export: name: "build-export (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" - runs-on: linux-x86-n2-32 + runs-on: linux-g2-16-l4-1gpu-x4 container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e defaults: run: From 05d440196cfb360591192b2d70e13a816d13de15 Mon Sep 17 00:00:00 2001 From: Quoc Truong Date: Fri, 15 Aug 2025 18:59:52 -0700 Subject: [PATCH 13/22] Remove sleep --- .github/workflows/build.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index eeb9387d7..f81d86056 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -110,8 +110,6 @@ jobs: with: version: '3.x' repo-token: ${{ secrets.GITHUB_TOKEN }} - - name: sleep - run: sleep 10000 - name: Extract branch name shell: bash run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT From 5a4a91dd1acea256d3041f7cb4feecce782c2dd0 Mon Sep 17 00:00:00 2001 From: Quoc Truong Date: Sat, 16 Aug 2025 11:12:07 -0700 Subject: [PATCH 14/22] Remove sleep --- .github/workflows/build.yml | 118 ++++++++---------------------------- 1 file changed, 24 insertions(+), 94 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index f81d86056..fc0cba975 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -22,69 +22,9 @@ concurrency: cancel-in-progress: true jobs: - build-checkpoint: - name: "build-checkpoint (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" - runs-on: linux-x86-n2-32 - container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e - defaults: - run: - working-directory: checkpoint - strategy: - matrix: - python-version: ["3.10"] - jax-version: ["newest"] - include: - - python-version: "3.10" - jax-version: "0.5.0" # keep in sync with minimum version in checkpoint/pyproject.toml - # TODO(b/401258175) Re-enable once JAX nightlies are fixed. - # - python-version: "3.13" - # jax-version: "nightly" - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - # TODO(b/275613424): remove `pip install -e .` and `pip uninstall -y orbax`. - # Currently in place to override remote orbax import due to flax dependency. - run: | - pip install -e . - pip install -e .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - pip uninstall -y orbax - if [[ "${{ matrix.jax-version }}" == "newest" ]]; then - pip install -U jax jaxlib - elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then - pip install -U --pre jax jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/ - else - pip install "jax>=${{ matrix.jax-version }}" "jaxlib>=${{ matrix.jax-version }}" - fi - - name: Test with pytest - # TODO(yaning): Move these to an exclude target within pytest.ini. - run: | - python -m pytest --ignore=orbax/checkpoint/experimental/emergency/broadcast_multislice_test.py --ignore=orbax/checkpoint/experimental/emergency/checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/single_slice_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_data_debugging_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/multihost_test.py --ignore=orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager_test.py --ignore=orbax/checkpoint/_src/testing/multiprocess_test.py - # The below step just reports the success or failure of tests as a "commit status". - # This is needed for copybara integration. - - name: Report success or failure as github status - if: always() - shell: bash - run: | - status="${{ job.status }}" - lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]') - curl -sS --request POST \ - --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \ - --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \ - --header 'content-type: application/json' \ - --data '{ - "state": "'$lowercase_status'", - "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}", - "description": "'$status'", - "context": "github-actions/build" - }' - - build-export: + build-export-duplicate: name: "build-export (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" - runs-on: linux-g2-16-l4-1gpu-x4 + runs-on: linux-x86-n2-32 container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e defaults: run: @@ -92,13 +32,7 @@ jobs: strategy: matrix: python-version: ["3.10"] - jax-version: ["newest"] - include: - - python-version: "3.10" - jax-version: "0.4.34" # keep in sync with minimum version in export/pyproject.toml - # TODO(b/401258175) Re-enable once JAX nightlies are fixed. - # - python-version: "3.12" # TODO(jakevdp): update to 3.13 when tf supports it. - # jax-version: "nightly" + jax-version: ["0.4.34"] steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} @@ -114,6 +48,8 @@ jobs: shell: bash run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT id: extract_branch + - name: sleep + run: sleep 10000 - name: Install dependencies run: | sudo apt-get update @@ -157,22 +93,18 @@ jobs: "context": "github-actions/build" }' - build-orbax-model: - name: "build-orbax-model (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" - runs-on: linux-x86-n2-32 + + build-export: + name: "build-export (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" + runs-on: linux-g2-16-l4-1gpu-x4 container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e defaults: run: - working-directory: model + working-directory: export strategy: matrix: - python-version: ["3.10", "3.11", "3.12"] - jax-version: ["newest"] - include: - - python-version: "3.10" - jax-version: "0.5.0" # keep in sync with minimum version in experimental/model/pyproject.toml - # - python-version: "3.13" - # jax-version: "nightly" + python-version: ["3.10"] + jax-version: ["0.4.34"] steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} @@ -188,19 +120,17 @@ jobs: shell: bash run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT id: extract_branch + - name: sleep + run: sleep 10000 - name: Install dependencies run: | sudo apt-get update sudo apt-get install -y protobuf-compiler - pip install tensorflow - - protoc -I=. --python_out=. $(find orbax/experimental/model/ -name "*.proto") - - pip install -e . + protoc -I=. --python_out=. $(find orbax/export/ -name "*.proto") + pip install . pip install .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - if [[ "${{ matrix.jax-version }}" == "newest" ]]; then pip install -U jax jaxlib elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then @@ -210,14 +140,14 @@ jobs: fi - name: Test with pytest run: | - pytest orbax/experimental/model/core/python/*_test.py - - pytest orbax/experimental/model/tf2obm/*_test.py - - pytest orbax/experimental/model/jax2obm/ \ - --ignore=orbax/experimental/model/jax2obm/main_lib_test.py \ - --ignore=orbax/experimental/model/jax2obm/sharding_test.py \ - --ignore=orbax/experimental/model/jax2obm/jax_to_polymorphic_function_test.py + test_dir=$(mktemp -d) + cp orbax/export/conftest.py ${test_dir} + for t in $(find orbax/export -maxdepth 1 -name '*_test.py'); do + cp ${t} ${test_dir} + XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest ${test_dir}/$(basename ${t}) + done + # The below step just reports the success or failure of tests as a "commit status". + # This is needed for copybara integration. - name: Report success or failure as github status if: always() shell: bash From 650e8eb022b10dfdd62125820ad02d31293d4a3f Mon Sep 17 00:00:00 2001 From: Quoc Truong Date: Mon, 18 Aug 2025 10:14:28 -0700 Subject: [PATCH 15/22] Add github workflow --- .github/workflows/build.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index fc0cba975..e0b769b4b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -23,7 +23,7 @@ concurrency: jobs: build-export-duplicate: - name: "build-export (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" + name: "build-export n2 (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" runs-on: linux-x86-n2-32 container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e defaults: @@ -95,7 +95,7 @@ jobs: build-export: - name: "build-export (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" + name: "build-export g2 (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" runs-on: linux-g2-16-l4-1gpu-x4 container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e defaults: From eb82e63561038750ea96adc739056a424e287604 Mon Sep 17 00:00:00 2001 From: Quoc Truong Date: Mon, 18 Aug 2025 10:56:46 -0700 Subject: [PATCH 16/22] move sleep to end --- .github/workflows/build.yml | 65 ++++++++++--------------------------- 1 file changed, 18 insertions(+), 47 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e0b769b4b..38179a948 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -48,8 +48,6 @@ jobs: shell: bash run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT id: extract_branch - - name: sleep - run: sleep 10000 - name: Install dependencies run: | sudo apt-get update @@ -66,32 +64,20 @@ jobs: else pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" fi - - name: Test with pytest - run: | - test_dir=$(mktemp -d) - cp orbax/export/conftest.py ${test_dir} - for t in $(find orbax/export -maxdepth 1 -name '*_test.py'); do - cp ${t} ${test_dir} - XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest ${test_dir}/$(basename ${t}) - done +# - name: Test with pytest +# run: | +# test_dir=$(mktemp -d) +# cp orbax/export/conftest.py ${test_dir} +# for t in $(find orbax/export -maxdepth 1 -name '*_test.py'); do +# cp ${t} ${test_dir} +# XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest ${test_dir}/$(basename ${t}) +# done # The below step just reports the success or failure of tests as a "commit status". # This is needed for copybara integration. - name: Report success or failure as github status if: always() shell: bash - run: | - status="${{ job.status }}" - lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]') - curl -sS --request POST \ - --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \ - --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \ - --header 'content-type: application/json' \ - --data '{ - "state": "'$lowercase_status'", - "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}", - "description": "'$status'", - "context": "github-actions/build" - }' + run: sleep 10000 build-export: @@ -120,8 +106,6 @@ jobs: shell: bash run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT id: extract_branch - - name: sleep - run: sleep 10000 - name: Install dependencies run: | sudo apt-get update @@ -138,29 +122,16 @@ jobs: else pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" fi - - name: Test with pytest - run: | - test_dir=$(mktemp -d) - cp orbax/export/conftest.py ${test_dir} - for t in $(find orbax/export -maxdepth 1 -name '*_test.py'); do - cp ${t} ${test_dir} - XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest ${test_dir}/$(basename ${t}) - done +# - name: Test with pytest +# run: | +# test_dir=$(mktemp -d) +# cp orbax/export/conftest.py ${test_dir} +# for t in $(find orbax/export -maxdepth 1 -name '*_test.py'); do +# cp ${t} ${test_dir} +# XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest ${test_dir}/$(basename ${t}) +# done # The below step just reports the success or failure of tests as a "commit status". # This is needed for copybara integration. - name: Report success or failure as github status if: always() - shell: bash - run: | - status="${{ job.status }}" - lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]') - curl -sS --request POST \ - --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \ - --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \ - --header 'content-type: application/json' \ - --data '{ - "state": "'$lowercase_status'", - "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}", - "description": "'$status'", - "context": "github-actions/build" - }' + run: sleep 10000 \ No newline at end of file From 60696380d3976a6abc97554cfc4ed998bfa077ed Mon Sep 17 00:00:00 2001 From: Quoc Truong Date: Tue, 19 Aug 2025 16:39:25 -0700 Subject: [PATCH 17/22] Try workflow --- .github/workflows/build.yml | 167 ++++++++++++++++++++++++++++++++---- 1 file changed, 151 insertions(+), 16 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 38179a948..b9d202159 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -22,16 +22,70 @@ concurrency: cancel-in-progress: true jobs: - build-export-duplicate: - name: "build-export n2 (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" - runs-on: linux-x86-n2-32 + build-checkpoint: + name: "build-checkpoint (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" + runs-on: linux-g2-16-l4-1gpu-x4 + container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e + defaults: + run: + working-directory: checkpoint + strategy: + matrix: + python-version: ["3.12"] + jax-version: ["0.5.0"] + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + # TODO(b/275613424): remove `pip install -e .` and `pip uninstall -y orbax`. + # Currently in place to override remote orbax import due to flax dependency. + run: | + pip install -e . + pip install -e .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + pip uninstall -y orbax + if [[ "${{ matrix.jax-version }}" == "newest" ]]; then + pip install -U jax jaxlib + elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then + pip install -U --pre jax jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/ + else + pip install "jax>=${{ matrix.jax-version }}" "jaxlib>=${{ matrix.jax-version }}" + fi + - name: Test with pytest + # TODO(yaning): Move these to an exclude target within pytest.ini. + run: | + python -m pytest --ignore=orbax/checkpoint/experimental/emergency/broadcast_multislice_test.py --ignore=orbax/checkpoint/experimental/emergency/checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/single_slice_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_data_debugging_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/multihost_test.py --ignore=orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager_test.py --ignore=orbax/checkpoint/_src/testing/multiprocess_test.py + # The below step just reports the success or failure of tests as a "commit status". + # This is needed for copybara integration. + - name: Report success or failure as github status + if: always() + shell: bash + run: | + status="${{ job.status }}" + lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]') + curl -sS --request POST \ + --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \ + --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \ + --header 'content-type: application/json' \ + --data '{ + "state": "'$lowercase_status'", + "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}", + "description": "'$status'", + "context": "github-actions/build" + }' + + build-export: + name: "build-export (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" + runs-on: linux-g2-16-l4-1gpu-x4 container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e defaults: run: working-directory: export strategy: matrix: - python-version: ["3.10"] + python-version: ["3.12"] jax-version: ["0.4.34"] steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -64,26 +118,107 @@ jobs: else pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" fi -# - name: Test with pytest -# run: | -# test_dir=$(mktemp -d) -# cp orbax/export/conftest.py ${test_dir} -# for t in $(find orbax/export -maxdepth 1 -name '*_test.py'); do -# cp ${t} ${test_dir} -# XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest ${test_dir}/$(basename ${t}) -# done + - name: Test with pytest + run: | + test_dir=$(mktemp -d) + cp orbax/export/conftest.py ${test_dir} + for t in $(find orbax/export -maxdepth 1 -name '*_test.py'); do + cp ${t} ${test_dir} + XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest ${test_dir}/$(basename ${t}) + done # The below step just reports the success or failure of tests as a "commit status". # This is needed for copybara integration. - name: Report success or failure as github status if: always() shell: bash - run: sleep 10000 - + run: | + status="${{ job.status }}" + lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]') + curl -sS --request POST \ + --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \ + --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \ + --header 'content-type: application/json' \ + --data '{ + "state": "'$lowercase_status'", + "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}", + "description": "'$status'", + "context": "github-actions/build" + }' - build-export: - name: "build-export g2 (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" + build-orbax-model: + name: "build-orbax-model (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" runs-on: linux-g2-16-l4-1gpu-x4 container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e + defaults: + run: + working-directory: model + strategy: + matrix: + python-version: ["3.12"] + jax-version: ["0.5.0"] + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: ${{ matrix.python-version }} + - name: Install Protoc + uses: arduino/setup-protoc@v1 + with: + version: '3.x' + repo-token: ${{ secrets.GITHUB_TOKEN }} + - name: Extract branch name + shell: bash + run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT + id: extract_branch + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y protobuf-compiler + + pip install tensorflow + + protoc -I=. --python_out=. $(find orbax/experimental/model/ -name "*.proto") + + pip install -e . + + pip install .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + + if [[ "${{ matrix.jax-version }}" == "newest" ]]; then + pip install -U jax jaxlib + elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then + pip install -U --pre jax jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/ + else + pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" + fi + - name: Test with pytest + run: | + pytest orbax/experimental/model/core/python/*_test.py + + pytest orbax/experimental/model/tf2obm/*_test.py + + pytest orbax/experimental/model/jax2obm/ \ + --ignore=orbax/experimental/model/jax2obm/main_lib_test.py \ + --ignore=orbax/experimental/model/jax2obm/sharding_test.py \ + --ignore=orbax/experimental/model/jax2obm/jax_to_polymorphic_function_test.py + - name: Report success or failure as github status + if: always() + shell: bash + run: | + status="${{ job.status }}" + lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]') + curl -sS --request POST \ + --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \ + --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \ + --header 'content-type: application/json' \ + --data '{ + "state": "'$lowercase_status'", + "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}", + "description": "'$status'", + "context": "github-actions/build" + }' + + name: "build-export g2 (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" defaults: run: working-directory: export From 62573328b615bdcfedaf221abbaab65e9682c986 Mon Sep 17 00:00:00 2001 From: Quoc Truong Date: Tue, 19 Aug 2025 16:42:15 -0700 Subject: [PATCH 18/22] Fix build --- .github/workflows/build.yml | 53 ------------------------------------- 1 file changed, 53 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index b9d202159..06b8dd01b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -217,56 +217,3 @@ jobs: "description": "'$status'", "context": "github-actions/build" }' - - name: "build-export g2 (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" - defaults: - run: - working-directory: export - strategy: - matrix: - python-version: ["3.10"] - jax-version: ["0.4.34"] - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - with: - python-version: ${{ matrix.python-version }} - - name: Install Protoc - uses: arduino/setup-protoc@v1 - with: - version: '3.x' - repo-token: ${{ secrets.GITHUB_TOKEN }} - - name: Extract branch name - shell: bash - run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT - id: extract_branch - - name: Install dependencies - run: | - sudo apt-get update - sudo apt-get install -y protobuf-compiler - - protoc -I=. --python_out=. $(find orbax/export/ -name "*.proto") - - pip install . - pip install .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - if [[ "${{ matrix.jax-version }}" == "newest" ]]; then - pip install -U jax jaxlib - elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then - pip install -U --pre jax jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/ - else - pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" - fi -# - name: Test with pytest -# run: | -# test_dir=$(mktemp -d) -# cp orbax/export/conftest.py ${test_dir} -# for t in $(find orbax/export -maxdepth 1 -name '*_test.py'); do -# cp ${t} ${test_dir} -# XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest ${test_dir}/$(basename ${t}) -# done - # The below step just reports the success or failure of tests as a "commit status". - # This is needed for copybara integration. - - name: Report success or failure as github status - if: always() - run: sleep 10000 \ No newline at end of file From a6cab4d77e55e42968e95399adfc35f0242f6068 Mon Sep 17 00:00:00 2001 From: Quoc Truong Date: Thu, 16 Oct 2025 12:33:55 -0700 Subject: [PATCH 19/22] Refactor GitHub Actions build workflow Updated build workflow to include multiprocess checkpoint benchmarks and modified JAX installation commands. Signed-off-by: Quoc Truong --- .github/workflows/build.yml | 187 +++--------------------------------- 1 file changed, 12 insertions(+), 175 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 06b8dd01b..a4425f860 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1,5 +1,4 @@ name: build - on: # continuous schedule: @@ -12,18 +11,15 @@ on: pull_request: branches: - main - permissions: contents: read actions: write # to cancel previous workflows - concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} cancel-in-progress: true - jobs: - build-checkpoint: - name: "build-checkpoint (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" + multiprocess-checkpoint-benchmarks: + name: "multiprocess-checkpoint-benchmarks (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" runs-on: linux-g2-16-l4-1gpu-x4 container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e defaults: @@ -40,180 +36,21 @@ jobs: with: python-version: ${{ matrix.python-version }} - name: Install dependencies - # TODO(b/275613424): remove `pip install -e .` and `pip uninstall -y orbax`. - # Currently in place to override remote orbax import due to flax dependency. run: | pip install -e . - pip install -e .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + pip install -e .[testing,gcs] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html pip uninstall -y orbax if [[ "${{ matrix.jax-version }}" == "newest" ]]; then - pip install -U jax jaxlib + pip install -U jax[k8s,cuda12] jaxlib elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then - pip install -U --pre jax jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/ + pip install -U --pre jax[k8s,cuda12] jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/ else - pip install "jax>=${{ matrix.jax-version }}" "jaxlib>=${{ matrix.jax-version }}" + pip install "jax[k8s,cuda12]>=${{ matrix.jax-version }}" "jaxlib>=${{ matrix.jax-version }}" fi - - name: Test with pytest - # TODO(yaning): Move these to an exclude target within pytest.ini. + pip install gcsfs + pip install portpicker + - name: Run benchmarks + env: + GCS_BUCKET_PATH: gs://orbax-benchmarks/benchmark-results/${{ github.run_id }} run: | - python -m pytest --ignore=orbax/checkpoint/experimental/emergency/broadcast_multislice_test.py --ignore=orbax/checkpoint/experimental/emergency/checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/single_slice_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_data_debugging_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/multihost_test.py --ignore=orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager_test.py --ignore=orbax/checkpoint/_src/testing/multiprocess_test.py - # The below step just reports the success or failure of tests as a "commit status". - # This is needed for copybara integration. - - name: Report success or failure as github status - if: always() - shell: bash - run: | - status="${{ job.status }}" - lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]') - curl -sS --request POST \ - --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \ - --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \ - --header 'content-type: application/json' \ - --data '{ - "state": "'$lowercase_status'", - "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}", - "description": "'$status'", - "context": "github-actions/build" - }' - - build-export: - name: "build-export (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" - runs-on: linux-g2-16-l4-1gpu-x4 - container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e - defaults: - run: - working-directory: export - strategy: - matrix: - python-version: ["3.12"] - jax-version: ["0.4.34"] - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - with: - python-version: ${{ matrix.python-version }} - - name: Install Protoc - uses: arduino/setup-protoc@v1 - with: - version: '3.x' - repo-token: ${{ secrets.GITHUB_TOKEN }} - - name: Extract branch name - shell: bash - run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT - id: extract_branch - - name: Install dependencies - run: | - sudo apt-get update - sudo apt-get install -y protobuf-compiler - - protoc -I=. --python_out=. $(find orbax/export/ -name "*.proto") - - pip install . - pip install .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - if [[ "${{ matrix.jax-version }}" == "newest" ]]; then - pip install -U jax jaxlib - elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then - pip install -U --pre jax jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/ - else - pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" - fi - - name: Test with pytest - run: | - test_dir=$(mktemp -d) - cp orbax/export/conftest.py ${test_dir} - for t in $(find orbax/export -maxdepth 1 -name '*_test.py'); do - cp ${t} ${test_dir} - XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest ${test_dir}/$(basename ${t}) - done - # The below step just reports the success or failure of tests as a "commit status". - # This is needed for copybara integration. - - name: Report success or failure as github status - if: always() - shell: bash - run: | - status="${{ job.status }}" - lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]') - curl -sS --request POST \ - --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \ - --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \ - --header 'content-type: application/json' \ - --data '{ - "state": "'$lowercase_status'", - "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}", - "description": "'$status'", - "context": "github-actions/build" - }' - - build-orbax-model: - name: "build-orbax-model (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" - runs-on: linux-g2-16-l4-1gpu-x4 - container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e - defaults: - run: - working-directory: model - strategy: - matrix: - python-version: ["3.12"] - jax-version: ["0.5.0"] - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - with: - python-version: ${{ matrix.python-version }} - - name: Install Protoc - uses: arduino/setup-protoc@v1 - with: - version: '3.x' - repo-token: ${{ secrets.GITHUB_TOKEN }} - - name: Extract branch name - shell: bash - run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT - id: extract_branch - - name: Install dependencies - run: | - sudo apt-get update - sudo apt-get install -y protobuf-compiler - - pip install tensorflow - - protoc -I=. --python_out=. $(find orbax/experimental/model/ -name "*.proto") - - pip install -e . - - pip install .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - - if [[ "${{ matrix.jax-version }}" == "newest" ]]; then - pip install -U jax jaxlib - elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then - pip install -U --pre jax jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/ - else - pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" - fi - - name: Test with pytest - run: | - pytest orbax/experimental/model/core/python/*_test.py - - pytest orbax/experimental/model/tf2obm/*_test.py - - pytest orbax/experimental/model/jax2obm/ \ - --ignore=orbax/experimental/model/jax2obm/main_lib_test.py \ - --ignore=orbax/experimental/model/jax2obm/sharding_test.py \ - --ignore=orbax/experimental/model/jax2obm/jax_to_polymorphic_function_test.py - - name: Report success or failure as github status - if: always() - shell: bash - run: | - status="${{ job.status }}" - lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]') - curl -sS --request POST \ - --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \ - --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \ - --header 'content-type: application/json' \ - --data '{ - "state": "'$lowercase_status'", - "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}", - "description": "'$status'", - "context": "github-actions/build" - }' + python3 -c "import jax; jax.distributed.initialize(); print(jax.devices());" From 999440dab5997c225713a8f409ba310ef1e43925 Mon Sep 17 00:00:00 2001 From: Quoc Truong Date: Thu, 16 Oct 2025 13:29:59 -0700 Subject: [PATCH 20/22] Update JAX version in build workflow Signed-off-by: Quoc Truong --- .github/workflows/build.yml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index ab2838c1d..af79a386b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -28,10 +28,7 @@ jobs: strategy: matrix: python-version: ["3.12"] - jax-version: ["0.5.0"] - include: - - python-version: "3.10" - jax-version: "0.6.0" # keep in sync with minimum version in checkpoint/pyproject.toml + jax-version: ["0.8.0"] # TODO(b/401258175) Re-enable once JAX nightlies are fixed. # - python-version: "3.13" # jax-version: "nightly" From a0ddc44f53a99bd232c7ba886a004087079f18fe Mon Sep 17 00:00:00 2001 From: Quoc Truong Date: Thu, 16 Oct 2025 14:32:05 -0700 Subject: [PATCH 21/22] Change Python version in build workflow to 3.10 Signed-off-by: Quoc Truong --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index af79a386b..5a2f4235f 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -27,7 +27,7 @@ jobs: working-directory: checkpoint strategy: matrix: - python-version: ["3.12"] + python-version: ["3.10"] jax-version: ["0.8.0"] # TODO(b/401258175) Re-enable once JAX nightlies are fixed. # - python-version: "3.13" From f1abd52a8d15c1f19c16a411421fcfcded20d8a6 Mon Sep 17 00:00:00 2001 From: Quoc Truong Date: Thu, 16 Oct 2025 14:34:24 -0700 Subject: [PATCH 22/22] Update build.yml Signed-off-by: Quoc Truong --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 5a2f4235f..ac3a752db 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -28,7 +28,7 @@ jobs: strategy: matrix: python-version: ["3.10"] - jax-version: ["0.8.0"] + jax-version: ["0.6.0"] # TODO(b/401258175) Re-enable once JAX nightlies are fixed. # - python-version: "3.13" # jax-version: "nightly"