Fix libtpu version for torch and do not pre-install tensorflow-tpu on TPU. #1499
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
We install a libtpu version compatible with both jax 0.7.2 and torch 2.8.0. Why? tunix latest -> flax 0.12 -> jax 0.7.2 -> libtpu 0.0.23, and that libtpu version causes pjrt api errors for torch 2.8.0:
Of particular note, we no longer pre-install
tensorflow-tpu
as the newer libtpu causes issues finding the TPUsWe also update how we install Python packages via
uv
for consistency and reproducibility. From arequirements.in
file, we first generate a consistent dependency closure viauv pip compile
, and thenuv pip install
the packages from the generatedrequirements.txt
.