Skip to content

OSS async_utils_test.py and use operationId generator instead of passing operation id around. #37375

OSS async_utils_test.py and use operationId generator instead of passing operation id around.

OSS async_utils_test.py and use operationId generator instead of passing operation id around. #37375

Workflow file for this run

# Workflow header: display name, triggers, token permissions and run
# concurrency for the build workflow.
name: build

on:
  # continuous
  schedule:
    # Run every hour
    - cron: "0 * * * *"
  push:
    branches:
      - main
      - 'test_*'
  pull_request:
    branches:
      - main

# Scopes granted to GITHUB_TOKEN. Once any scope is listed here, every scope
# that is not listed is set to none.
permissions:
  contents: read
  actions: write  # to cancel previous workflows
  # Required by the "Report success or failure as github status" step, which
  # POSTs to the commit statuses REST endpoint using GITHUB_TOKEN.
  statuses: write

# Runs are grouped by workflow name plus the PR head ref (or the pushed ref
# when there is no PR); starting a new run in a group cancels the one in
# progress.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  cancel-in-progress: true
jobs:
  # Installs the package under `checkpoint/` for each Python/JAX combination
  # in the matrix, runs its pytest suite, and then publishes the job result
  # as a commit status.
  build-checkpoint:
    name: "build-checkpoint (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})"
    runs-on: ubuntu-latest
    defaults:
      run:
        working-directory: checkpoint
    strategy:
      matrix:
        python-version: ["3.10", "3.11", "3.12"]
        jax-version: ["newest"]
        include:
          - python-version: "3.10"
            jax-version: "0.6.0"  # keep in sync with minimum version in checkpoint/pyproject.toml
          # TODO(b/401258175) Re-enable once JAX nightlies are fixed.
          # - python-version: "3.13"
          #   jax-version: "nightly"
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b  # v5.3.0
        with:
          python-version: ${{ matrix.python-version }}
      - name: Free up space
        run: |
          pip cache purge
          sudo docker image prune -a -f
          df -h
      - name: Install dependencies
        # TODO(b/275613424): remove `pip install -e .` and `pip uninstall -y orbax`.
        # Currently in place to override remote orbax import due to flax dependency.
        env:
          # Passed through the environment so the script text itself does not
          # depend on the matrix value.
          JAX_VERSION: ${{ matrix.jax-version }}
        run: |
          pip install --no-cache-dir -e .
          # Quoted so bash never treats `.[testing]` as a glob pattern.
          pip install --no-cache-dir -e ".[testing]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          pip uninstall -y orbax
          case "${JAX_VERSION}" in
            newest)
              pip install -U jax jaxlib
              ;;
            nightly)
              pip install -U --pre jax jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/
              ;;
            *)
              # NOTE(review): `>=` is already satisfied by any newer jax that
              # an earlier install step may have pulled in, so this branch may
              # not actually test the minimum version - confirm whether `==`
              # was intended.
              pip install "jax>=${JAX_VERSION}" "jaxlib>=${JAX_VERSION}"
              ;;
          esac
      - name: Test with pytest
        # TODO(yaning): Move these to an exclude target within pytest.ini.
        run: |
          python -m pytest --import-mode=importlib \
            --ignore=orbax/checkpoint/experimental/emergency/broadcast_multislice_test.py \
            --ignore=orbax/checkpoint/experimental/emergency/checkpoint_manager_test.py \
            --ignore=orbax/checkpoint/experimental/emergency/single_slice_checkpoint_manager_test.py \
            --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_data_debugging_test.py \
            --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_manager_test.py \
            --ignore=orbax/checkpoint/experimental/emergency/multihost_test.py \
            --ignore=orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager_test.py \
            --ignore=orbax/checkpoint/_src/testing/multiprocess_test.py \
            --ignore=orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py \
            --ignore=orbax/checkpoint/checkpoint_manager_test.py
      # The below step just reports the success or failure of tests as a "commit status".
      # This is needed for copybara integration.
      - name: Report success or failure as github status
        if: always()
        shell: bash
        env:
          # Handed to the script as environment variables instead of being
          # spliced into the shell source.
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          JOB_STATUS: ${{ job.status }}
        run: |
          status="${JOB_STATUS}"
          state=$(echo "${status}" | tr '[:upper:]' '[:lower:]')
          # The commit status API accepts only error, failure, pending and
          # success. A cancelled job (this step runs under always()) would
          # otherwise send an invalid state, so report it as error; the
          # description still carries the original job status.
          if [[ "${state}" == "cancelled" ]]; then
            state="error"
          fi
          curl -sS --request POST \
            --url "https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }}" \
            --header "authorization: Bearer ${GITHUB_TOKEN}" \
            --header 'content-type: application/json' \
            --data '{
              "state": "'"${state}"'",
              "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}",
              "description": "'"${status}"'",
              "context": "github-actions/build"
              }'
# build-export:
# name: "build-export (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})"
# runs-on: ubuntu-latest
# defaults:
# run:
# working-directory: export
# strategy:
# matrix:
# python-version: ["3.10", "3.11", "3.12"]
# jax-version: ["newest"]
# include:
# - python-version: "3.10"
# jax-version: "0.4.34" # keep in sync with minimum version in export/pyproject.toml
# # TODO(b/401258175) Re-enable once JAX nightlies are fixed.
# # - python-version: "3.12" # TODO(jakevdp): update to 3.13 when tf supports it.
# # jax-version: "nightly"
# steps:
# - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
# - name: Set up Python ${{ matrix.python-version }}
# uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
# with:
# python-version: ${{ matrix.python-version }}
# - name: Install Protoc
# uses: arduino/setup-protoc@v1
# with:
# version: '3.x'
# repo-token: ${{ secrets.GITHUB_TOKEN }}
# - name: Extract branch name
# shell: bash
# run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT
# id: extract_branch
# - name: Install dependencies
# run: |
# sudo apt-get update
# sudo apt-get install -y protobuf-compiler
# protoc -I=. --python_out=. $(find orbax/export/ -name "*.proto")
# pip install .
# pip install .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
# if [[ "${{ matrix.jax-version }}" == "newest" ]]; then
# pip install -U jax jaxlib
# elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then
# pip install -U --pre jax jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/
# else
# pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
# fi
# - name: Test with pytest
# run: |
# test_dir=$(mktemp -d)
# cp orbax/export/conftest.py ${test_dir}
# for t in $(find orbax/export -maxdepth 1 -name '*_test.py'); do
# cp ${t} ${test_dir}
# XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest ${test_dir}/$(basename ${t})
# done
# # The below step just reports the success or failure of tests as a "commit status".
# # This is needed for copybara integration.
# - name: Report success or failure as github status
# if: always()
# shell: bash
# run: |
# status="${{ job.status }}"
# lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]')
# curl -sS --request POST \
# --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \
# --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \
# --header 'content-type: application/json' \
# --data '{
# "state": "'$lowercase_status'",
# "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}",
# "description": "'$status'",
# "context": "github-actions/build"
# }'
# build-orbax-model:
# name: "build-orbax-model (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})"
# runs-on: ubuntu-latest
# defaults:
# run:
# working-directory: model
# strategy:
# matrix:
# python-version: ["3.10", "3.11", "3.12"]
# jax-version: ["newest"]
# include:
# - python-version: "3.10"
# jax-version: "0.5.0" # keep in sync with minimum version in experimental/model/pyproject.toml
# # - python-version: "3.13"
# # jax-version: "nightly"
# steps:
# - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
# - name: Set up Python ${{ matrix.python-version }}
# uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
# with:
# python-version: ${{ matrix.python-version }}
# - name: Install Protoc
# uses: arduino/setup-protoc@v1
# with:
# version: '3.x'
# repo-token: ${{ secrets.GITHUB_TOKEN }}
# - name: Extract branch name
# shell: bash
# run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT
# id: extract_branch
# - name: Install dependencies
# run: |
# sudo apt-get update
# sudo apt-get install -y protobuf-compiler
# pip install tensorflow
# protoc -I=. --python_out=. $(find orbax/experimental/model/ -name "*.proto")
# pip install -e .
# pip install .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
# if [[ "${{ matrix.jax-version }}" == "newest" ]]; then
# pip install -U jax jaxlib
# elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then
# pip install -U --pre jax jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/
# else
# pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
# fi
# - name: Test with pytest
# run: |
# pytest orbax/experimental/model/core/python/*_test.py
# pytest orbax/experimental/model/tf2obm/*_test.py
# pytest orbax/experimental/model/jax2obm/ \
# --ignore=orbax/experimental/model/jax2obm/main_lib_test.py \
# --ignore=orbax/experimental/model/jax2obm/sharding_test.py \
# --ignore=orbax/experimental/model/jax2obm/jax_to_polymorphic_function_test.py
# - name: Report success or failure as github status
# if: always()
# shell: bash
# run: |
# status="${{ job.status }}"
# lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]')
# curl -sS --request POST \
# --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \
# --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \
# --header 'content-type: application/json' \
# --data '{
# "state": "'$lowercase_status'",
# "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}",
# "description": "'$status'",
# "context": "github-actions/build"
# }'
# multiprocess-checkpoint-benchmarks:
# name: "multiprocess-checkpoint-benchmarks (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})"
# runs-on: linux-g2-16-l4-1gpu-x4
# # runs-on: linux-x86-ct5lp-4tpu-x4
# container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e
# defaults:
# run:
# working-directory: checkpoint
# strategy:
# matrix:
# python-version: ["3.10", "3.11", "3.12"]
# jax-version: ["0.6.0"]
# steps:
# - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
# - name: Set up Python ${{ matrix.python-version }}
# uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
# with:
# python-version: ${{ matrix.python-version }}
# - name: Install dependencies
# run: |
# pip install -e .
# pip install -e .[testing,gcs] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
# pip uninstall -y orbax
# if [[ "${{ matrix.jax-version }}" == "newest" ]]; then
# pip install -U jax[k8s,cuda12] jaxlib
# elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then
# pip install -U --pre jax[k8s,cuda12] jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/
# else
# pip install "jax[k8s,cuda12]>=${{ matrix.jax-version }}" "jaxlib>=${{ matrix.jax-version }}"
# fi
# pip install gcsfs
# pip install portpicker
# - name: Run benchmarks
# env:
# GCS_BUCKET_PATH: gs://orbax-benchmarks/benchmark-results/${{ github.run_id }}
# run: |
# cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
# cd ../../../../..
# # python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py
# # cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
# # The below step just reports the success or failure of tests as a "commit status".
# # This is needed for copybara integration.
# - name: Run multiprocess tests
# env:
# TEST_TMPDIR: gs://orbax-benchmarks/unit-tests/${{ github.run_id }}
# run: |
# python -c "import jax; jax.distributed.initialize(); print(jax.devices()); import pytest; test_files = [line.strip() for line in open('orbax/checkpoint/_src/testing/multiprocess_tests.txt') if line.strip()]; pytest.main(['-c', '/dev/null'] + test_files)"
# # python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py;"
# # cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
# # python -m pytest orbax/checkpoint/checkpoint_manager_test.py
# - name: Report success or failure as github status
# if: always()
# shell: bash
# run: |
# status="${{ job.status }}"
# lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]')
# curl -sS --request POST \
# --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \
# --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \
# --header 'content-type: application/json' \
# --data '{
# "state": "'$lowercase_status'",
# "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}",
# "description": "'$status'",
# "context": "github-actions/build"
# }'