forked from Lightning-AI/lightning-thunder
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdtensor_codeutils.py
More file actions
35 lines (29 loc) · 1.01 KB
/
dtensor_codeutils.py
File metadata and controls
35 lines (29 loc) · 1.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from typing import Any
from torch.distributed.tensor._dtensor_spec import DTensorSpec, DeviceMesh, TensorMeta
from torch.distributed.tensor import DeviceMesh, Partial, Placement, Replicate, Shard
def populate_object_ctx_for_dtensor_spec(x: Any, object_ctx: dict[str, Any]) -> bool:
"""
Populate object context for DTensorSpec.
..note::
This function will mutate the `object_ctx`
Returns:
bool: True if `x` is DTensorSpec (and also updates `object_ctx`) otherwise False.
"""
if isinstance(x, DTensorSpec):
object_ctx.update(
{
"DTensorSpec": DTensorSpec,
"DeviceMesh": DeviceMesh,
"Placement": Placement,
"Replicate": Replicate,
"Shard": Shard,
"Partial": Partial,
"TensorMeta": TensorMeta,
}
)
return True
return False
def prettyprint_dtensor_spec(x):
if isinstance(x, DTensorSpec):
return x.__repr__()
return ""