diff --git a/maxdiffusion_jax_ai_image_tpu.Dockerfile b/maxdiffusion_jax_ai_image_tpu.Dockerfile index 301f9b88..86969b09 100644 --- a/maxdiffusion_jax_ai_image_tpu.Dockerfile +++ b/maxdiffusion_jax_ai_image_tpu.Dockerfile @@ -3,6 +3,8 @@ ARG JAX_AI_IMAGE_BASEIMAGE # JAX AI Base Image FROM $JAX_AI_IMAGE_BASEIMAGE +ARG JAX_AI_IMAGE_BASEIMAGE + ARG COMMIT_HASH ENV COMMIT_HASH=$COMMIT_HASH @@ -18,5 +20,12 @@ COPY . . # Install Maxdiffusion Jax AI Image requirements RUN pip install -r /deps/requirements_with_jax_ai_image.txt +# TODO: Remove the flax pin and fsspec overrides once flax stable version releases +RUN if echo "$JAX_AI_IMAGE_BASEIMAGE" | grep -q "nightly"; then \ + echo "Nightly build detected: Installing specific Flax commit and fsspec." && \ + pip install --upgrade --force-reinstall git+https://github.com/google/flax.git@ef78d6584623511746be4824965cdef42b464583 && \ + pip install "fsspec==2025.10.0"; \ + fi + # Run the script available in JAX-AI-Image base image to generate the manifest file RUN bash /jax-ai-image/generate_manifest.sh PREFIX=maxdiffusion COMMIT_HASH=$COMMIT_HASH \ No newline at end of file