Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 26 additions & 7 deletions tpu/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,39 @@ RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y
# Additional useful packages should be added to the requirements file.
# Copy the requirements template into the image and expand the ${...}
# placeholders in it with envsubst (gettext is presumably installed for
# envsubst -- TODO confirm).
# NOTE(review): both a requirements.txt and a requirements.in ADD/envsubst pair
# appear below. This looks like the removed and added sides of a rendered
# diff -- confirm which pair the real Dockerfile keeps.
# NOTE(review): this install has no `apt-get update` of its own, so it depends
# on the package index left behind by an earlier layer -- confirm that is
# intended.
RUN apt-get install -y gettext
ADD tpu/requirements.txt /kaggle_requirements.txt
RUN envsubst < /kaggle_requirements.txt > /requirements.txt
ADD tpu/requirements.in /kaggle_requirements.in
RUN envsubst < /kaggle_requirements.in > /requirements.in

# Install uv and then install the requirements:
# NOTE(review): the installer script is fetched unpinned and piped straight
# into sh, without pipefail -- consider pinning a uv version and verifying the
# download.
RUN curl -LsSf https://astral.sh/uv/install.sh | sh
# Two variants of the install step follow:
#   1. a direct `uv pip install -r /requirements.txt`;
#   2. `uv pip compile`, which resolves /requirements.in into
#      /requirements.txt using the extra --find-links indexes listed below
#      (keeping pip and setuptools out of the output), followed by
#      `uv pip install -r /requirements.txt`, `uv cache clean` and
#      /tmp/clean-layer.sh.
# Each RUN prepends ${HOME}/.local/bin to PATH itself, presumably because that
# is where the installer places uv -- TODO confirm.
# NOTE(review): the first RUN ends in a line continuation and is immediately
# followed by a second RUN. This looks like the removed and added sides of a
# rendered diff -- confirm which variant the real Dockerfile keeps.
RUN export PATH="${HOME}/.local/bin:${PATH}" && uv pip install --system -r /requirements.txt --prerelease=allow --force-reinstall && \
RUN export PATH="${HOME}/.local/bin:${PATH}" && \
uv pip compile --system --prerelease=allow \
--verbose \
--upgrade \
--find-links=https://storage.googleapis.com/jax-releases/libtpu_releases.html \
--find-links=https://storage.googleapis.com/libtpu-releases/index.html \
--find-links=https://storage.googleapis.com/libtpu-wheels/index.html \
--find-links=https://download.pytorch.org/whl/torch_stable.html \
--emit-find-links \
--no-emit-package pip \
--no-emit-package setuptools \
--output-file /requirements.txt \
/requirements.in && \
uv pip install --system --prerelease=allow --force-reinstall \
-r /requirements.txt && \
uv cache clean && \
/tmp/clean-layer.sh
# NOTE(review): ENV performs ${VAR} substitution but no tilde expansion, so
# PATH receives a literal "~/.local/bin" entry. The RUN steps above use
# ${HOME}/.local/bin instead -- confirm this entry resolves as intended at
# runtime.
ENV PATH="~/.local/bin:${PATH}"

# Try to force tensorflow to reliably install without breaking other installed deps
# We install a libtpu version compatible with both jax 0.7.2 and torch 2.8.0.
# Why? Latest tunix -> flax 0.12 -> jax 0.7.2 -> libtpu 0.0.23. However, that
# libtpu causes PJRT API errors for torch 2.8.0. screenshot/5heUtdyaJ4MmR3D
# https://github.com/pytorch/xla/blob/d517649bdef6ab0519c30c704bde8779c8216502/setup.py#L111
# https://github.com/jax-ml/jax/blob/3489529b38d1f11d1e5caf4540775aadd5f2cdda/setup.py#L26
# NOTE(review): two variants of this step appear in the RUN below:
#   - freeze the installed packages into /tmp/constraints.txt, reinstall
#     tensorflow-tpu under those constraints, then remove the constraints file;
#   - force-reinstall libtpu==0.0.17, then `uv cache clean` and
#     /tmp/clean-layer.sh.
# The first comment line above matches the first variant and the remaining
# comment lines match the second. This looks like the removed and added sides
# of a rendered diff -- confirm which variant the real Dockerfile keeps.
RUN export PATH="${HOME}/.local/bin:${PATH}" && \
uv pip freeze --system > /tmp/constraints.txt && \
uv pip install --system -c /tmp/constraints.txt tensorflow-tpu -f https://storage.googleapis.com/libtpu-tf-releases/index.html --force-reinstall && \
rm /tmp/constraints.txt
uv pip install --system --force-reinstall libtpu==0.0.17 && \
uv cache clean && \
/tmp/clean-layer.sh

# Kaggle Model Hub patches:
ADD patches/kaggle_module_resolver.py /usr/local/lib/${PYTHON_VERSION_PATH}/site-packages/tensorflow_hub/kaggle_module_resolver.py
Expand Down
6 changes: 2 additions & 4 deletions tpu/requirements.txt → tpu/requirements.in
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# TPU Utils
tpu-info
# Tensorflow packages
tensorflow-tpu==${TENSORFLOW_VERSION}
--find-links https://storage.googleapis.com/libtpu-tf-releases/index.html
tensorflow-cpu==${TENSORFLOW_VERSION}
tensorflow_hub
tensorflow-io
tensorflow-probability
Expand All @@ -13,8 +12,7 @@ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-${TOR
torchaudio==${TORCHAUDIO_VERSION}
torchvision==${TORCHVISION_VERSION}
# Jax packages
jax[tpu]>=0.5.2
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
jax[tpu]
distrax
flax
git+https://github.com/deepmind/dm-haiku
Expand Down