Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 16 additions & 190 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
name: build

on:
# continuous
schedule:
Expand All @@ -12,29 +11,24 @@ on:
pull_request:
branches:
- main

permissions:
contents: read
actions: write # to cancel previous workflows

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
cancel-in-progress: true

jobs:
build-checkpoint:
name: "build-checkpoint (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})"
runs-on: ubuntu-latest
multiprocess-checkpoint-benchmarks:
name: "multiprocess-checkpoint-benchmarks (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})"
runs-on: linux-g2-16-l4-1gpu-x4
container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e
defaults:
run:
working-directory: checkpoint
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12"]
jax-version: ["newest"]
include:
- python-version: "3.10"
jax-version: "0.6.0" # keep in sync with minimum version in checkpoint/pyproject.toml
python-version: ["3.10"]
jax-version: ["0.6.0"]
# TODO(b/401258175) Re-enable once JAX nightlies are fixed.
# - python-version: "3.13"
# jax-version: "nightly"
Expand All @@ -45,189 +39,21 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
# TODO(b/275613424): remove `pip install -e .` and `pip uninstall -y orbax`.
# Currently in place to override remote orbax import due to flax dependency.
run: |
pip install -e .
pip install -e .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install -e .[testing,gcs] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip uninstall -y orbax
if [[ "${{ matrix.jax-version }}" == "newest" ]]; then
pip install -U jax jaxlib
pip install -U jax[k8s,cuda12] jaxlib
elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then
pip install -U --pre jax jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/
pip install -U --pre jax[k8s,cuda12] jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/
else
pip install "jax>=${{ matrix.jax-version }}" "jaxlib>=${{ matrix.jax-version }}"
pip install "jax[k8s,cuda12]>=${{ matrix.jax-version }}" "jaxlib>=${{ matrix.jax-version }}"
fi
- name: Test with pytest
# TODO(yaning): Move these to an exclude target within pytest.ini.
pip install gcsfs
pip install portpicker
- name: Run benchmarks
env:
GCS_BUCKET_PATH: gs://orbax-benchmarks/benchmark-results/${{ github.run_id }}
run: |
python -m pytest --ignore=orbax/checkpoint/experimental/emergency/broadcast_multislice_test.py --ignore=orbax/checkpoint/experimental/emergency/checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/single_slice_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_data_debugging_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/multihost_test.py --ignore=orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager_test.py --ignore=orbax/checkpoint/_src/testing/multiprocess_test.py
# The below step just reports the success or failure of tests as a "commit status".
# This is needed for copybara integration.
- name: Report success or failure as github status
if: always()
shell: bash
run: |
status="${{ job.status }}"
lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]')
curl -sS --request POST \
--url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \
--header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \
--header 'content-type: application/json' \
--data '{
"state": "'$lowercase_status'",
"target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}",
"description": "'$status'",
"context": "github-actions/build"
}'

build-export:
name: "build-export (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})"
runs-on: ubuntu-latest
defaults:
run:
working-directory: export
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12"]
jax-version: ["newest"]
include:
- python-version: "3.10"
jax-version: "0.4.34" # keep in sync with minimum version in export/pyproject.toml
# TODO(b/401258175) Re-enable once JAX nightlies are fixed.
# - python-version: "3.12" # TODO(jakevdp): update to 3.13 when tf supports it.
# jax-version: "nightly"
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{ matrix.python-version }}
- name: Install Protoc
uses: arduino/setup-protoc@v1
with:
version: '3.x'
repo-token: ${{ secrets.GITHUB_TOKEN }}
- name: Extract branch name
shell: bash
run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT
id: extract_branch
- name: Install dependencies
run: |
sudo apt-get update
sudo apt-get install -y protobuf-compiler

protoc -I=. --python_out=. $(find orbax/export/ -name "*.proto")

pip install .
pip install .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
if [[ "${{ matrix.jax-version }}" == "newest" ]]; then
pip install -U jax jaxlib
elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then
pip install -U --pre jax jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/
else
pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
fi
- name: Test with pytest
run: |
test_dir=$(mktemp -d)
cp orbax/export/conftest.py ${test_dir}
for t in $(find orbax/export -maxdepth 1 -name '*_test.py'); do
cp ${t} ${test_dir}
XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest ${test_dir}/$(basename ${t})
done
# The below step just reports the success or failure of tests as a "commit status".
# This is needed for copybara integration.
- name: Report success or failure as github status
if: always()
shell: bash
run: |
status="${{ job.status }}"
lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]')
curl -sS --request POST \
--url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \
--header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \
--header 'content-type: application/json' \
--data '{
"state": "'$lowercase_status'",
"target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}",
"description": "'$status'",
"context": "github-actions/build"
}'

build-orbax-model:
name: "build-orbax-model (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})"
runs-on: ubuntu-latest
defaults:
run:
working-directory: model
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12"]
jax-version: ["newest"]
include:
- python-version: "3.10"
jax-version: "0.5.0" # keep in sync with minimum version in experimental/model/pyproject.toml
# - python-version: "3.13"
# jax-version: "nightly"
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{ matrix.python-version }}
- name: Install Protoc
uses: arduino/setup-protoc@v1
with:
version: '3.x'
repo-token: ${{ secrets.GITHUB_TOKEN }}
- name: Extract branch name
shell: bash
run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT
id: extract_branch
- name: Install dependencies
run: |
sudo apt-get update
sudo apt-get install -y protobuf-compiler

pip install tensorflow

protoc -I=. --python_out=. $(find orbax/experimental/model/ -name "*.proto")

pip install -e .

pip install .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

if [[ "${{ matrix.jax-version }}" == "newest" ]]; then
pip install -U jax jaxlib
elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then
pip install -U --pre jax jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/
else
pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
fi
- name: Test with pytest
run: |
pytest orbax/experimental/model/core/python/*_test.py

pytest orbax/experimental/model/tf2obm/*_test.py

pytest orbax/experimental/model/jax2obm/ \
--ignore=orbax/experimental/model/jax2obm/main_lib_test.py \
--ignore=orbax/experimental/model/jax2obm/sharding_test.py \
--ignore=orbax/experimental/model/jax2obm/jax_to_polymorphic_function_test.py
- name: Report success or failure as github status
if: always()
shell: bash
run: |
status="${{ job.status }}"
lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]')
curl -sS --request POST \
--url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \
--header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \
--header 'content-type: application/json' \
--data '{
"state": "'$lowercase_status'",
"target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}",
"description": "'$status'",
"context": "github-actions/build"
}'
python3 -c "import jax; jax.distributed.initialize(); print(jax.devices());"
Loading