Merged
62 commits
5873742
dtensor support
kshitij12345 Mar 21, 2025
377125a
add comment
kshitij12345 Mar 24, 2025
7ab82f6
add more comments
kshitij12345 Mar 24, 2025
e6aa8d3
update comment
kshitij12345 Mar 24, 2025
e76fc17
add test for expected failing cases
kshitij12345 Mar 24, 2025
eaac9f7
support for method
kshitij12345 Mar 24, 2025
94ef69d
update failing case test
kshitij12345 Mar 24, 2025
5d81851
remove generated traces
kshitij12345 Mar 26, 2025
7277753
undo pre-commit change
kshitij12345 Mar 26, 2025
a8c58e4
undo debug changes
kshitij12345 Mar 26, 2025
d87b103
update failing test to use thunder.jit
kshitij12345 Mar 26, 2025
b101161
update registration helper
kshitij12345 Mar 26, 2025
b551cb8
Apply suggestions from code review
kshitij12345 Mar 31, 2025
1c75a80
Merge branch 'main' of github.com:Lightning-AI/lightning-thunder into…
kshitij12345 Apr 1, 2025
5854c86
address review and update
kshitij12345 Apr 1, 2025
a778830
update dtensor proxy repr
kshitij12345 Apr 2, 2025
41990d0
Merge branch 'main' of github.com:Lightning-AI/lightning-thunder into…
kshitij12345 Apr 2, 2025
eda0277
Merge branch 'main' of github.com:Lightning-AI/lightning-thunder into…
kshitij12345 Apr 3, 2025
8abf040
Merge branch 'main' of github.com:Lightning-AI/lightning-thunder into…
kshitij12345 Apr 9, 2025
225f2e3
update jit_ext access to torchfn_to_thunder registry : test
kshitij12345 Apr 9, 2025
2b85b31
empty commit
kshitij12345 Apr 9, 2025
5d0296f
Revert "update jit_ext access to torchfn_to_thunder registry : test"
kshitij12345 Apr 10, 2025
dedab03
temp commit
kshitij12345 Apr 15, 2025
efaae1d
Merge branch 'main' of github.com:Lightning-AI/lightning-thunder into…
kshitij12345 Apr 24, 2025
ddcf208
update to manual decomp
kshitij12345 Apr 24, 2025
6a6bf11
add manual grad rule
kshitij12345 Apr 24, 2025
2a8ea02
update
kshitij12345 May 14, 2025
9490cac
update - clean-up
kshitij12345 May 14, 2025
bd1ecbb
update attrs on DTensorProxy
kshitij12345 May 14, 2025
5d1f20b
Merge branch 'main' of github.com:Lightning-AI/lightning-thunder into…
kshitij12345 May 14, 2025
255c82d
remove debug change
kshitij12345 May 15, 2025
dba02d2
remove unused imports
kshitij12345 May 15, 2025
e8f6d0b
remove unused import
kshitij12345 May 15, 2025
83ae80a
update function name
kshitij12345 May 15, 2025
4cd9bec
cotangent metadata check initial support
kshitij12345 May 16, 2025
b49df80
address review : p1
kshitij12345 May 16, 2025
b206632
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 16, 2025
b70238a
address review
kshitij12345 May 16, 2025
1ff76b3
Merge branch 'dtensor-init-support' of https://github.com/kshitij1234…
kshitij12345 May 16, 2025
3e295da
update and refactor
kshitij12345 May 16, 2025
8f4a029
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 16, 2025
a445053
update and add test
kshitij12345 Jun 6, 2025
530d29c
update test
kshitij12345 Jun 6, 2025
56dc94e
Merge branch 'main' of https://github.com/Lightning-AI/lightning-thun…
kshitij12345 Jun 6, 2025
e285e97
undo changes for thunderfx path, fix later
kshitij12345 Jun 6, 2025
284d109
Move import of is_dtensor_proxy to the top of the file
IvanYashchuk Jun 11, 2025
b26e8b7
Move import of check_dtensor_spec_repr to the top of the file
IvanYashchuk Jun 11, 2025
a7eed41
format dtensor imports
IvanYashchuk Jun 11, 2025
8bd1b99
Move import of handle_check_dtensor_spec_in_prologue to the top of th…
IvanYashchuk Jun 11, 2025
8cad766
Return only fake tensors from run_with_fake_tensor
IvanYashchuk Jun 11, 2025
f4dc375
Remove unused aot_function
IvanYashchuk Jun 11, 2025
c8459a2
Remove TracingContext, tracing usage
IvanYashchuk Jun 11, 2025
2cf4c39
Inline FakeTensorMode into _run_with_fake
IvanYashchuk Jun 11, 2025
1949011
Rename _run_with_fake->run_with_fake_tensor
IvanYashchuk Jun 11, 2025
fdfbb79
Merge branch 'main' of https://github.com/Lightning-AI/lightning-thun…
kshitij12345 Jun 12, 2025
031eba0
address review
kshitij12345 Jun 12, 2025
d600791
Merge branch 'dtensor-init-support' of https://github.com/kshitij1234…
kshitij12345 Jun 12, 2025
384b7c6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 12, 2025
8bc5228
update
kshitij12345 Jun 12, 2025
f688c98
Merge branch 'dtensor-init-support' of https://github.com/kshitij1234…
kshitij12345 Jun 12, 2025
27a4c26
xfail constant_folding test
kshitij12345 Jun 13, 2025
9291923
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 13, 2025
dtensor support
kshitij12345 committed Mar 21, 2025
commit 587374271042b53baf3b6f92cf4c4a1d5168fc68
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -77,3 +77,5 @@ repos:
# https://prettier.io/docs/en/options.html#print-width
files: \.(json|yml|yaml|toml)
args: ["--print-width=120"]

exclude: "generated_dtensor_trcs/.*"
37 changes: 37 additions & 0 deletions class_structure.py
@@ -0,0 +1,37 @@
from typing import NamedTuple, Optional, Tuple

import torch


class TensorMeta(NamedTuple):
    # Simple named tuple representing tensor metadata.
    # Intentionally kept minimal; used only for sharding
    # propagation purposes.
    shape: torch.Size
    stride: Tuple[int, ...]
    dtype: torch.dtype


class Placement:
    def is_shard(self) -> bool:
        pass

    def is_replicate(self) -> bool:
        pass

    def is_partial(self) -> bool:
        pass


class DeviceMesh:
    device_type: str
    mesh: torch.Tensor
    mesh_dim_names: Optional[Tuple[str, ...]]


class DTensorSpec:
    mesh: DeviceMesh
    placements: Tuple[Placement, ...]

    # tensor_meta will only be set during sharding propagation
    tensor_meta: Optional[TensorMeta] = None


class DTensor:
    _local_tensor: torch.Tensor
    _spec: DTensorSpec
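The wrap/unwrap pattern implied by this structure — read out `_local_tensor`, run the op on the local shard, rebuild a DTensor from the result plus a `DTensorSpec` — can be sketched with plain-Python stand-ins. All names here (`MiniDTensor`, `MiniSpec`, the helpers) are hypothetical, not the real torch.distributed classes:

```python
from dataclasses import dataclass

# Hypothetical stand-ins mirroring the classes above; real DTensors hold
# torch.Tensor shards and a richer DTensorSpec.
@dataclass(frozen=True)
class MiniSpec:
    placements: tuple

@dataclass
class MiniDTensor:
    local_tensor: list   # stand-in for the local torch.Tensor shard
    spec: MiniSpec

def get_inner_tensor(dt: MiniDTensor) -> list:
    # analogous to reading _local_tensor
    return dt.local_tensor

def construct(local: list, spec: MiniSpec) -> MiniDTensor:
    # analogous to rebuilding a DTensor from a local shard and a spec
    return MiniDTensor(local, spec)

def dtensor_mul(a: MiniDTensor, b: MiniDTensor) -> MiniDTensor:
    # unwrap -> local elementwise op -> rewrap with the input's spec
    la, lb = get_inner_tensor(a), get_inner_tensor(b)
    return construct([x * y for x, y in zip(la, lb)], a.spec)

x = MiniDTensor([1.0, 2.0], MiniSpec(("Shard(dim=0)",)))
w = MiniDTensor([3.0, 4.0], MiniSpec(("Shard(dim=0)",)))
y = dtensor_mul(x, w)  # local result [3.0, 8.0], same spec as x
```

The generated traces below record exactly this unwrap/op/rewrap sequence, with `get_dtensor_inner_tensor` and `construct_dtensor` playing the roles of the two helpers.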
32 changes: 32 additions & 0 deletions generated_dtensor_trcs/dtensor_bwd_exec_trc.py
@@ -0,0 +1,32 @@
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
    # saved_for_backward: "Collection"
    # cotangents: "Collection"
    C0, _, = saved_for_backward
    # C0: "Collection"
    # None
    clear_mutable_collection(saved_for_backward)
    del saved_for_backward
    t2, = cotangents
    # t2: "DTensor cuda:0 f32[16, 16]"
    clear_mutable_collection(cotangents)
    del cotangents
    t21, = C0
    # t21: "cuda:0 f32[8, 16]"
    clear_mutable_collection(C0)
    del C0
    bw_t25 = get_dtensor_inner_tensor(t2)  # bw_t25: "cuda:0 f32[8, 16]"
        # bw_t25 = thunder.torch.experimental.dtensor_prims_and_impl.get_dtensor_inner_tensor(t2)  # bw_t25: "cuda:0 f32[8, 16]"
    del t2
    [bw_t13] = nvFusion0(t21, bw_t25)
        # bw_t13 = prims.mul(t21, bw_t25)  # bw_t13: "cuda:0 f32[8, 16]"
    del t21, bw_t25
    bw_t27 = construct_dtensor(bw_t13, DTensorSpec(mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),), tensor_meta=TensorMeta(shape=(16, 16), stride=(16, 1), dtype=torch.float32)))  # bw_t27: "DTensor cuda:0 f32[16, 16]"
        # bw_t27 = thunder.torch.experimental.dtensor_prims_and_impl.construct_dtensor(bw_t13, DTensorSpec(mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),), tensor_meta=TensorMeta(shape=(16, 16), stride=(16, 1), dtype=torch.float32)))  # bw_t27: "DTensor cuda:0 f32[16, 16]"
    del bw_t13
    return (bw_t27, None)
29 changes: 29 additions & 0 deletions generated_dtensor_trcs/dtensor_bwd_init_trc.py
@@ -0,0 +1,29 @@
# Constructed by Backward pass
import thunder
import thunder.torch as ltorch
import thunder.torch.experimental.dtensor_prims_and_impl
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
    # saved_for_backward: "Collection"
    # cotangents: "Collection"
    C0, C1, = saved_for_backward
    # C0: "Collection"
    # C1: "Collection"
    t2, = cotangents
    # t2: "DTensor cuda:0 f32[16, 16]"
    t20, t21, = C0
    # t20: "cuda:0 f32[8, 16]"
    # t21: "cuda:0 f32[8, 16]"
    # C1 (empty sequence)
    t11 = thunder.torch.experimental.dtensor_prims_and_impl.get_dtensor_inner_tensor(t2)  # t11: "cuda:0 f32[8, 16]"
    t13 = ltorch.mul(t21, t11)  # t13: "cuda:0 f32[8, 16]"
        # t13 = prims.mul(t21, t11)  # t13: "cuda:0 f32[8, 16]"
    t14 = ltorch.mul(t20, t11)  # t14: "cuda:0 f32[8, 16]"
        # t14 = prims.mul(t20, t11)  # t14: "cuda:0 f32[8, 16]"
    t16 = thunder.torch.experimental.dtensor_prims_and_impl.construct_dtensor(t14, DTensorSpec(mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),), tensor_meta=TensorMeta(shape=(16, 16), stride=(16, 1), dtype=torch.float32)))  # t16: "DTensor cuda:0 f32[16, 16]"
    t18 = thunder.torch.experimental.dtensor_prims_and_impl.construct_dtensor(t13, DTensorSpec(mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),), tensor_meta=TensorMeta(shape=(16, 16), stride=(16, 1), dtype=torch.float32)))  # t18: "DTensor cuda:0 f32[16, 16]"
    return (t18, None)
24 changes: 24 additions & 0 deletions generated_dtensor_trcs/dtensor_exec_trc.py
@@ -0,0 +1,24 @@
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(l_x_, l_w_):
    # l_x_: "DTensor cuda:0 f32[16, 16]"
    # l_w_: "DTensor cuda:0 f32[16, 16]"

    # <eval_with_key>.10:5: mul = torch.mul(l_x_, l_w_); l_x_ = l_w_ = None
    t20 = get_dtensor_inner_tensor(l_x_)  # t20: "cuda:0 f32[8, 16]"
        # t20 = thunder.torch.experimental.dtensor_prims_and_impl.get_dtensor_inner_tensor(l_x_)  # t20: "cuda:0 f32[8, 16]"
    t21 = get_dtensor_inner_tensor(l_w_)  # t21: "cuda:0 f32[8, 16]"
        # t21 = thunder.torch.experimental.dtensor_prims_and_impl.get_dtensor_inner_tensor(l_w_)  # t21: "cuda:0 f32[8, 16]"
    [t10] = nvFusion0(t20, t21)
        # t10 = prims.mul(t20, t21)  # t10: "cuda:0 f32[8, 16]"
    del t20

    # <eval_with_key>.10:5: mul = torch.mul(l_x_, l_w_); l_x_ = l_w_ = None
    t23 = construct_dtensor(t10, DTensorSpec(mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),), tensor_meta=TensorMeta(shape=(16, 16), stride=(16, 1), dtype=torch.float32)))  # t23: "DTensor cuda:0 f32[16, 16]"
        # t23 = thunder.torch.experimental.dtensor_prims_and_impl.construct_dtensor(t10, DTensorSpec(mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),), tensor_meta=TensorMeta(shape=(16, 16), stride=(16, 1), dtype=torch.float32)))  # t23: "DTensor cuda:0 f32[16, 16]"
    del t10
    return {'output': (t23,), 'flat_args': [l_x_, l_w_], 'flat_output': (t23,)}, ((t21,), ())
19 changes: 19 additions & 0 deletions generated_dtensor_trcs/dtensor_init_trc.py
@@ -0,0 +1,19 @@
import thunder
import thunder.torch.experimental.dtensor_torch_and_aten_ops
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(l_x_, l_w_):
    # l_x_: "DTensor cuda:0 f32[16, 16]"
    # l_w_: "DTensor cuda:0 f32[16, 16]"

    # <eval_with_key>.10:5: mul = torch.mul(l_x_, l_w_); l_x_ = l_w_ = None
    mul = thunder.torch.experimental.dtensor_torch_and_aten_ops.dtensor_mul(l_x_, l_w_)  # mul: "DTensor cuda:0 f32[16, 16]"
        # t4 = thunder.torch.experimental.dtensor_prims_and_impl.get_dtensor_inner_tensor(l_x_)  # t4: "cuda:0 f32[8, 16]"
        # t5 = thunder.torch.experimental.dtensor_prims_and_impl.get_dtensor_inner_tensor(l_w_)  # t5: "cuda:0 f32[8, 16]"
        # t0 = thunder.torch.experimental.dtensor_torch_and_aten_ops.aten_mul(t4, t5)  # t0: "cuda:0 f32[8, 16]"
            # t0 = prims.mul(t4, t5)  # t0: "cuda:0 f32[8, 16]"
        # mul = thunder.torch.experimental.dtensor_prims_and_impl.construct_dtensor(t0, DTensorSpec(mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),), tensor_meta=TensorMeta(shape=(16, 16), stride=(16, 1), dtype=torch.float32)))  # mul: "DTensor cuda:0 f32[16, 16]"
    return (mul,)
39 changes: 39 additions & 0 deletions generated_dtensor_trcs/dtensor_pro_trc.py
@@ -0,0 +1,39 @@
import thunder
import thunder.core.prims as prims
import thunder.torch.experimental.dtensor_prims_and_impl
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def prologue(*args, **kwargs):
    # args: "Any"
    prims.check_len(args, 2)
    # kwargs: "Any"
    prims.check_len(kwargs, 0)
    l_x_: "DTensor cuda:0 f32[16, 16]" = args[0]
    l_w_: "DTensor cuda:0 f32[16, 16]" = args[1]
    dtensor_spec0: "<class 'NoneType'>" = l_x_._spec
    thunder.torch.experimental.dtensor_prims_and_impl.check_dtensor_spec_repr(dtensor_spec0, "DTensorSpec(mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),), tensor_meta=TensorMeta(shape=torch.Size([16, 16]), stride=(16, 1), dtype=torch.float32))")
    t1: "cuda:0 f32[8, 16]" = l_x_._local_tensor
    prims.check_tensor_shape_and_metadata(t1, (8, 16), 'cuda:0', torch.float32, True)
    prims.check_tensor_shape_and_metadata(l_x_, (16, 16), 'cuda:0', torch.float32, True)
    dtensor_spec2: "<class 'NoneType'>" = l_w_._spec
    thunder.torch.experimental.dtensor_prims_and_impl.check_dtensor_spec_repr(dtensor_spec2, "DTensorSpec(mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),), tensor_meta=TensorMeta(shape=torch.Size([16, 16]), stride=(16, 1), dtype=torch.float32))")
    t3: "cuda:0 f32[8, 16]" = l_w_._local_tensor
    prims.check_tensor_shape_and_metadata(t3, (8, 16), 'cuda:0', torch.float32, False)
    prims.check_tensor_shape_and_metadata(l_w_, (16, 16), 'cuda:0', torch.float32, False)
    cache_info: "Any" = thunder._get_cache_info()
    cache_info_default_dtype: "<class 'torch.dtype'>" = cache_info['default_dtype']
    prims.check_literal_like(cache_info_default_dtype, torch.float32)
    cache_info_default_device: "<class 'torch.device'>" = cache_info['default_device']
    prims.check_literal_like(cache_info_default_device, torch.device("cpu"))
    cache_info_is_autocast_enabled: "bool False" = cache_info['is_autocast_enabled']
    prims.check_number_type_and_value(cache_info_is_autocast_enabled, False)
    cache_info_alias_tensor_indices: "str" = cache_info['alias_tensor_indices']
    prims.check_string_value(cache_info_alias_tensor_indices, '')
    cache_info_is_grad_enabled: "bool True" = cache_info['is_grad_enabled']
    prims.check_number_type_and_value(cache_info_is_grad_enabled, True)
    cache_info_no_grad_sync: "bool False" = cache_info['no_grad_sync']
    prims.check_number_type_and_value(cache_info_no_grad_sync, False)
    return ((l_x_, l_w_), ())
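The `check_dtensor_spec_repr` guard above compares the printed form of the runtime spec against the string recorded at trace time. A minimal sketch of that idea, using a hypothetical `FakeSpec` dataclass instead of the real DTensorSpec:

```python
from dataclasses import dataclass

@dataclass
class FakeSpec:
    # hypothetical stand-in for torch's DTensorSpec
    placements: tuple

def check_spec_repr(spec, expected_repr: str) -> None:
    # guard passes only when the runtime spec prints exactly as recorded
    if repr(spec) != expected_repr:
        raise RuntimeError(f"spec changed: {spec!r} != {expected_repr}")

s = FakeSpec(placements=("Shard(dim=0)",))
check_spec_repr(s, "FakeSpec(placements=('Shard(dim=0)',))")  # passes

try:
    check_spec_repr(FakeSpec(placements=("Replicate()",)),
                    "FakeSpec(placements=('Shard(dim=0)',))")
except RuntimeError:
    pass  # a mismatched spec fails the prologue and triggers recompilation
```

Repr-based comparison is a coarse but simple guard: any change in mesh, placements, or tensor metadata shows up in the string.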
69 changes: 69 additions & 0 deletions test_dtensor.py
@@ -0,0 +1,69 @@
# torchrun --nnodes 1 --nproc-per-node 2 test_dtensor.py
import torch.nn as nn
import torch
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor
from torch.distributed.tensor.placement_types import Placement, Shard, Replicate
import os
from thunder.dynamo import thunderfx
import torch.distributed as dist

LOCAL_RANK = int(os.environ["LOCAL_RANK"])
num_devices = 2
mesh = DeviceMesh("cuda", list(range(num_devices)))

hidden_size = 16


def model(x, w):
    # return torch.nn.functional.linear(x, w)
    # return torch.add(x, w)
    return torch.mul(x, w)


weight = distribute_tensor(torch.randn(hidden_size, hidden_size, requires_grad=False), mesh, [Shard(0)])
bias = distribute_tensor(torch.randn(hidden_size, requires_grad=False), mesh, [Shard(0)])

in_dtensor = distribute_tensor(torch.randn(hidden_size, hidden_size, requires_grad=True), mesh, [Shard(0)])

expected = torch.compile(model)(in_dtensor, weight)
tmodel = thunderfx(model)
actual = tmodel(in_dtensor, weight)

# def model(x):
# return x + 1

# in_tensor = torch.randn(num_devices, 4)
# mesh = dist.device_mesh.init_device_mesh("cuda", [num_devices])
# in_dtensor = dist.tensor.distribute_tensor(in_tensor, mesh, [Shard(0)])

# print(in_dtensor.shape)

# expected = torch.compile(model)(in_dtensor)
# actual = thunderfx(model, nv_enable_matmul=True, nv_enable_linear=True)(in_dtensor)

torch.testing.assert_close(actual.to_local(), expected.to_local())

g_o = distribute_tensor(torch.ones(hidden_size, hidden_size), mesh, [Shard(0)])
expected_g = torch.autograd.grad(
    expected,
    (in_dtensor,),
    g_o,
)
actual_g = torch.autograd.grad(actual, (in_dtensor,), g_o)

torch.testing.assert_close(actual_g, expected_g)

if LOCAL_RANK == 0:
    import thunder

    thunder_fn = tmodel._backend.subgraph_infos[0].thunder_compiled_fns[0]
    traces = thunder.last_traces(thunder_fn)
    traces[0].save_trace("generated_dtensor_trcs/dtensor_init_trc.py")
    traces[-1].save_trace("generated_dtensor_trcs/dtensor_exec_trc.py")

    pro_traces = thunder.last_prologue_traces(thunder_fn)
    pro_traces[0].save_trace("generated_dtensor_trcs/dtensor_pro_trc.py")

    bwd_traces = thunder.last_backward_traces(thunder_fn)
    bwd_traces[0].save_trace("generated_dtensor_trcs/dtensor_bwd_init_trc.py")
    bwd_traces[-1].save_trace("generated_dtensor_trcs/dtensor_bwd_exec_trc.py")
8 changes: 7 additions & 1 deletion thunder/__init__.py
@@ -414,7 +414,13 @@ def _alias_tensor_of_args_kwargs_dict(*args, **kwargs) -> dict[int, list[int]]:
data_ptr_to_tensor_group_index = {}
tensor_group_index_to_tensor_indices = defaultdict(list)
for idx, t in enumerate(flat_args):
if pytorch.is_tensor(t) and t.layout == pytorch.strided:
# Use `type(t) is pytorch.Tensor` because tensor subclasses don't support calling
# data_ptr().
# E.g. RuntimeError: Attempted to access the data pointer on an invalid python storage. (data_ptr access on TensorSubclass)
#
# `isinstance(t, pytorch.Tensor)` or `pytorch.is_tensor(t)` would match all Tensor objects, including
# subclasses.
if type(t) is pytorch.Tensor and t.layout == pytorch.strided:
data_ptr = t.untyped_storage().data_ptr()
if data_ptr not in data_ptr_to_tensor_group_index:
data_ptr_to_tensor_group_index[data_ptr] = len(data_ptr_to_tensor_group_index)
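The distinction this change relies on — `type(t) is Tensor` rejects subclasses while `isinstance` accepts them — can be shown with plain classes standing in for `torch.Tensor` and a subclass such as DTensor (the names here are illustrative, not the real torch types):

```python
class Tensor:
    pass

class DTensorLike(Tensor):
    # stand-in for a tensor subclass that cannot service data_ptr()
    pass

t, dt = Tensor(), DTensorLike()

assert isinstance(dt, Tensor)      # isinstance matches subclasses too
assert type(dt) is not Tensor      # the strict check skips the subclass
assert type(t) is Tensor           # plain tensors still pass
```

The strict check keeps subclass instances out of the aliasing analysis, which would otherwise call `data_ptr()` on storage they may not back.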
10 changes: 10 additions & 0 deletions thunder/core/codeutils.py
@@ -152,6 +152,11 @@ def to_printable(
if isinstance(x, ProxyInterface):
return x

from thunder.torch.experimental.dtensor_codeutils import populate_object_ctx_for_dtensor_spec

if populate_object_ctx_for_dtensor_spec(x, object_ctx):
return x

if dataclasses.is_dataclass(x):
# Add `class` to the object_ctx so that we can reuse it during the trace execution.
if isinstance(x, type): # dataclass type
@@ -236,6 +241,11 @@ def prettyprint(
if isinstance(x, ContextObject):
return m(x.name)

from thunder.torch.experimental.dtensor_codeutils import prettyprint_dtensor_spec

if (dtensor_repr := prettyprint_dtensor_spec(x)) != "":
return m(dtensor_repr)

if dataclasses.is_dataclass(x):
# For a dataclass instance of class
# class MyContainer:
16 changes: 16 additions & 0 deletions thunder/core/jit_ext.py
@@ -74,6 +74,10 @@
from thunder.torch import _torch_to_thunder_function_map
from thunder.clang import _clang_fn_set
from thunder.core.pytree import tree_map, tree_iter
from thunder.torch.experimental.dtensor_torch_and_aten_ops import register_dtensor_and_aten_function

# TODO: Find a better place to register these ops (mostly in thunder/torch/__init__.py but without cyclical dependency).
register_dtensor_and_aten_function()

#
# jit_ext.py implements extensions of thunder's interpreter
@@ -259,9 +263,16 @@ def proxify(self, value: WrappedValue) -> Any:

if p is not uvalue:
value.register_proxy(p)

from thunder.torch.experimental.dtensor_proxy import is_dtensor_proxy
from thunder.torch.experimental import dtensor_prims_and_impl

# TODO: other caching modes
co: CACHE_OPTIONS = get_cache_option()
if co is CACHE_OPTIONS.CONSTANT_VALUES:
if is_dtensor_proxy(p):
# Add check for mesh and layout.
self.add_constraint((dtensor_prims_and_impl.check_dtensor_spec_repr, p, uvalue._spec))
self.add_constraint((clang.check_tensor_shape_and_metadata, p))
elif co is CACHE_OPTIONS.SYMBOLIC_VALUES:
# TODO: establish guarding logic to allow non-broadcast shape change
@@ -1802,6 +1813,11 @@ def is_variableified_tensorproxy(v: Variable | Proxy) -> Proxy:
if isinstance(s, Proxy):
unpack(s)

from thunder.torch.experimental.dtensor_prims_and_impl import handle_check_dtensor_spec_in_prologue

if handle_check_dtensor_spec_in_prologue(prim, prologue_trace, args):
continue

prim(*args)

cache_info = thunder._get_cache_info()
3 changes: 2 additions & 1 deletion thunder/core/prims.py
@@ -1830,7 +1830,8 @@ def _get_grad_meta(a: Number | NumberProxy | TensorProxy, /) -> Number | TensorP
utils.check_type(a, (Number, NumberProxy, TensorProxy))

if isinstance(a, TensorProxy):
return TensorProxy(like=a)
# NOTE: `a` could be a TensorProxy subclass and its type should be preserved.
return type(a)(like=a)

# NOTE a is a Number in this branch
return numberproxy(pytype(a), 0)
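The `type(a)(like=a)` idiom used in this change constructs the result through the runtime class, so a `TensorProxy` subclass (such as a DTensor proxy) comes back as the same subclass. A simplified sketch with hypothetical proxy classes (not Thunder's real constructors):

```python
class TensorProxy:
    def __init__(self, like=None):
        # copy minimal metadata from the template proxy
        self.shape = like.shape if like is not None else ()

class DTensorProxy(TensorProxy):
    pass

a = DTensorProxy()
a.shape = (16, 16)

preserved = type(a)(like=a)   # constructs a DTensorProxy
erased = TensorProxy(like=a)  # would silently drop the subclass

assert type(preserved) is DTensorProxy
assert type(erased) is TensorProxy
assert preserved.shape == (16, 16)
```

The same reasoning motivates the `ones_like` change in thunder/core/transforms.py below: hard-coding the base class would erase the DTensor-ness of gradients.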
6 changes: 6 additions & 0 deletions thunder/core/proxies.py
@@ -2050,6 +2050,12 @@ def proxy(x: Any, *, name: str | None = None, history: None | tuple = None) -> A
if x is ...:
return AnyProxy(x, name=name, history=history)

# Import here to avoid cyclical dependency.
from thunder.torch.experimental.dtensor_proxy import proxify_dtensor

if (dtensor_proxy := proxify_dtensor(x, name, history)) is not None:
return dtensor_proxy

if isinstance(x, torch.Tensor):
return tensorproxy(x, name=name, history=history)

3 changes: 2 additions & 1 deletion thunder/core/transforms.py
@@ -3095,7 +3095,8 @@ def forward_and_backward_from_trace(trace: Trace, torch_autograd=False) -> Forwa

def ones_like(x):
if isinstance(x, TensorProxy):
return full_like(x, fill_value=1)
# NOTE: x could be a subclass of TensorProxy and that should be preserved.
return type(x)(like=x)
elif isinstance(x, NumberProxy):
return type(x.value)(1)
else: