@metrizable metrizable commented Sep 26, 2025

We install a libtpu version compatible with both jax 0.7.2 and torch 2.8.0. Why? The latest tunix pulls in flax 0.12, which pulls in jax 0.7.2, which pulls in libtpu 0.0.23, and that libtpu version causes PJRT API errors with torch 2.8.0:

pjrt_c_api_helpers.cc:258] Unexpected error status Unexpected PJRT_Plugin_Attributes_Args size: expected 32, got 24. The plugin is likely built with a later version than the framework. This plugin is built with PJRT API version 0.75.
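
To see which versions of this dependency chain actually ended up in an image, a small check along the following lines may help. This is only a sketch: the helper name installed_versions is hypothetical, and the distribution names passed to it (jax, flax, torch, libtpu) are assumed to match the names the packages are published under.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed in this environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Distribution names are assumptions; adjust them to what the image installs.
    for name, version in installed_versions(["jax", "flax", "torch", "libtpu"]).items():
        print(f"{name}: {version or 'not installed'}")
```

Comparing the reported libtpu version against the one each framework expects is one way to spot the kind of mismatch shown in the log above before it surfaces at runtime.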

Of particular note, we no longer pre-install tensorflow-tpu, as the newer libtpu causes issues finding the TPUs:

external/local_xla/xla/stream_executor/tpu/tpu_platform_interface.cc:78] No TPU platform found. Platform manager status: OK

We also update how we install Python packages via uv for consistency and reproducibility. From a requirements.in file, we first generate a consistent dependency closure via uv pip compile, and then uv pip install the packages from the generated requirements.txt.
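
The two-step flow described above can be sketched as follows. This is an illustration under stated assumptions, not the build script from this PR: the helper names (compile_cmd, install_cmd, sync) are hypothetical, the file names are taken from the description, and any extra flags the real image build passes to uv are not shown.

```python
import subprocess


def compile_cmd(src="requirements.in", out="requirements.txt"):
    """Command that resolves the loose requirements into a pinned closure."""
    return ["uv", "pip", "compile", src, "-o", out]


def install_cmd(lock="requirements.txt"):
    """Command that installs exactly what the pinned file lists."""
    return ["uv", "pip", "install", "-r", lock]


def sync(src="requirements.in", lock="requirements.txt", run=subprocess.check_call):
    """Run compile, then install; `run` is injectable so the flow can be dry-run."""
    for cmd in (compile_cmd(src, lock), install_cmd(lock)):
        run(cmd)


if __name__ == "__main__":
    # Dry run: print the commands instead of executing them.
    sync(run=lambda cmd: print(" ".join(cmd)))
```

Splitting resolution from installation means the generated requirements.txt records the full pinned set, so a rebuild installs the same versions rather than re-resolving.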

@calderjo calderjo left a comment

awesome thanks!

@calderjo calderjo merged commit 3e031ba into Kaggle:main Sep 26, 2025
1 of 2 checks passed