Skip to content

Commit e885c8a

Browse files
skyetensorflower-gardener
authored andcommitted
[jax] Expose Device.local_hardware_id
This is to help us implement the new __dlpack_device__ interface (https://dmlc.github.io/dlpack/latest/python_spec.html). I'm also ok making this a private method, but figured it makes sense as a public Device method. PiperOrigin-RevId: 557286157
1 parent 318f291 commit e885c8a

File tree

4 files changed

+22
-1
lines changed

4 files changed

+22
-1
lines changed

tensorflow/compiler/xla/python/xla.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,18 @@ PYBIND11_MODULE(xla_extension, m) {
201201
[](const ClientAndPtr<PjRtDevice>& device) {
202202
return device.client();
203203
})
204+
.def_property_readonly(
205+
"local_hardware_id",
206+
[](const ClientAndPtr<PjRtDevice>& device) -> std::optional<int> {
207+
int local_hardware_id = device->local_hardware_id();
208+
if (local_hardware_id == -1) {
209+
return std::nullopt;
210+
}
211+
return local_hardware_id;
212+
},
213+
"Opaque hardware ID, e.g., the CUDA device number. In general, not "
214+
"guaranteed to be dense, and not guaranteed to be defined on all "
215+
"platforms.")
204216
.def("__str__", &PjRtDevice::DebugString)
205217
.def("__repr__", &PjRtDevice::ToString)
206218
.def("transfer_to_infeed",

tensorflow/compiler/xla/python/xla_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444

4545
# Just an internal arbitrary increasing number to help with backward-compatible
4646
# changes. In JAX, reference this via jax._src.lib.xla_extension_version.
47-
_version = 183
47+
_version = 184
4848

4949
# Version number for MLIR:Python components.
5050
mlir_api_version = 54

tensorflow/compiler/xla/python/xla_client_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2215,6 +2215,12 @@ def testPlatform(self):
22152215
for device in self.backend.local_devices():
22162216
self.assertEqual(device.platform, self.backend.platform)
22172217

2218+
def testLocalHardwareId(self):
2219+
for device in self.backend.devices():
2220+
local_hardware_id = device.local_hardware_id
2221+
if local_hardware_id is not None:
2222+
self.assertGreaterEqual(local_hardware_id, 0)
2223+
22182224
@unittest.skipIf(pathways, "not implemented")
22192225
def testMemoryStats(self):
22202226
for device in self.backend.local_devices():

tensorflow/compiler/xla/python/xla_extension/__init__.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515

16+
from __future__ import annotations
17+
1618
import enum
1719
import inspect
1820
import types
@@ -350,6 +352,7 @@ class Device:
350352
platform: str
351353
device_kind: str
352354
client: Client
355+
local_hardware_id: int | None
353356
def __repr__(self) -> str: ...
354357
def __str__(self) -> str: ...
355358
def transfer_to_infeed(self, literal: _LiteralSlice): ...

0 commit comments

Comments
 (0)