diff --git a/thunder/__init__.py b/thunder/__init__.py index 83065ff0b2..5a6b8f13c5 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -405,7 +405,13 @@ def _alias_tensor_of_args_kwargs_dict(*args, **kwargs) -> dict[int, list[int]]: data_ptr_to_tensor_group_index = {} tensor_group_index_to_tensor_indices = defaultdict(list) for idx, t in enumerate(flat_args): - if pytorch.is_tensor(t) and t.layout == pytorch.strided: + # Using type(t) is pytorch.Tensor as TensorSubclasses don't support calling + # data_ptr(). + # Eg. RuntimeError: Attempted to access the data pointer on an invalid python storage. (data_ptr access on TensorSubclass) + # + # isinstance(t, pytorch.Tensor) or pytorch.is_tensor(t) will match all Tensor objects including + # subclasses. + if type(t) is pytorch.Tensor and t.layout is pytorch.strided: data_ptr = t.untyped_storage().data_ptr() if data_ptr not in data_ptr_to_tensor_group_index: data_ptr_to_tensor_group_index[data_ptr] = len(data_ptr_to_tensor_group_index) diff --git a/thunder/core/codeutils.py b/thunder/core/codeutils.py index 5d4314d934..22c7df84ab 100644 --- a/thunder/core/codeutils.py +++ b/thunder/core/codeutils.py @@ -152,6 +152,11 @@ def to_printable( if isinstance(x, ProxyInterface): return x + from thunder.torch.experimental.dtensor_codeutils import populate_object_ctx_for_dtensor_spec + + if populate_object_ctx_for_dtensor_spec(x, object_ctx): + return x + if dataclasses.is_dataclass(x): # Add `class` to the object_ctx so that we can reuse it during the trace execution. 
if isinstance(x, type): # dataclass type @@ -236,6 +241,11 @@ def prettyprint( if isinstance(x, ContextObject): return m(x.name) + from thunder.torch.experimental.dtensor_codeutils import prettyprint_dtensor_spec + + if dtensor_repr := prettyprint_dtensor_spec(x): + return m(dtensor_repr) + if dataclasses.is_dataclass(x): # For a dataclass instance of class # class MyContainer: diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index a3fc9d6bde..b596c5fe81 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -74,6 +74,15 @@ from thunder.torch import _torch_to_thunder_function_map from thunder.clang import _clang_fn_set from thunder.core.pytree import tree_map, tree_iter +from thunder.torch.experimental.dtensor_proxy import is_dtensor_proxy +from thunder.torch.experimental.dtensor_torch_and_prims import ( + check_dtensor_spec_repr, + handle_check_dtensor_spec_in_prologue, + register_dtensor_torch_and_prims, +) + +# TODO: Find a better place to register these ops (mostly in thunder/torch/__init__.py but without cyclical dependency). 
+register_dtensor_torch_and_prims() # # jit_ext.py implements extensions of thunder's interpreter @@ -273,9 +282,12 @@ def proxify(self, value: WrappedValue) -> Any: if p is not uvalue: value.register_proxy(p) + # TODO: other caching modes co: CACHE_OPTIONS = get_cache_option() if co is CACHE_OPTIONS.CONSTANT_VALUES: + if is_dtensor_proxy(p): + self.add_constraint((check_dtensor_spec_repr, p, uvalue._spec)) self.add_constraint((clang.check_tensor_shape_and_metadata, p)) elif co is CACHE_OPTIONS.SYMBOLIC_VALUES: # TODO: establish guarding logic to allow non-broadcast shape change @@ -1880,6 +1892,10 @@ def is_variableified_tensorproxy(v: Variable | Proxy) -> Proxy: if isinstance(s, Proxy): unpack(s) + # Add checks for local tensor, mesh and placement of a DTensor + if handle_check_dtensor_spec_in_prologue(prim, prologue_trace, args): + continue + prim(*args) cache_info = thunder._get_cache_info() diff --git a/thunder/core/prims.py b/thunder/core/prims.py index 663406248a..9e6777a84c 100644 --- a/thunder/core/prims.py +++ b/thunder/core/prims.py @@ -1833,7 +1833,8 @@ def _get_grad_meta(a: Number | NumberProxy | TensorProxy, /) -> Number | TensorP utils.check_type(a, (Number, NumberProxy, TensorProxy)) if isinstance(a, TensorProxy): - return TensorProxy(like=a) + # NOTE: `a` could be a TensorProxy subclass and its type should be preserved. + return type(a)(like=a) # NOTE a is a Number in this branch return numberproxy(pytype(a), 0) diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index 39c0202fb7..ae1ed6836e 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -2071,6 +2071,12 @@ def proxy(x: Any, *, name: str | None = None, history: None | tuple = None) -> A if x is ...: return AnyProxy(x, name=name, history=history) + # Import here to avoid cyclical dependency. 
+ from thunder.torch.experimental.dtensor_proxy import proxify_dtensor + + if (dtensor_proxy := proxify_dtensor(x, name, history)) is not None: + return dtensor_proxy + if isinstance(x, torch.Tensor): return tensorproxy(x, name=name, history=history) diff --git a/thunder/core/pytree.py b/thunder/core/pytree.py index 51f99d23eb..b8d884aa2a 100644 --- a/thunder/core/pytree.py +++ b/thunder/core/pytree.py @@ -69,6 +69,7 @@ def tree_flatten(args, namespace=OPTREE_NAMESPACE): torch.autograd.function.FunctionCtx, immutable_list, *torch.types.py_sym_types, + *((torch.distributed._tensor.DTensor,) if torch.distributed.is_available() else ()), } and not isinstance(args, (ProxyInterface)) and not is_likely_from_collections_namedtuple(args) diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index a1c309465a..aa09be6a84 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -35,7 +35,7 @@ from thunder.core.langctxs import langctx, Languages from thunder.core.pytree import tree_flatten, tree_map, tree_unflatten, tree_flatten_with_dataclass from thunder.core.symbol import BoundSymbol, BoundSymbolInterface, Symbol, has_tags -from thunder.core.trace import TraceCtx as Trace +from thunder.core.trace import TraceCtx as Trace, get_tracectx from thunder.core.trace import VariableInterface as Variable from thunder.core.trace import ( detached_trace, @@ -3127,7 +3127,8 @@ def forward_and_backward_from_trace(trace: Trace, torch_autograd=False) -> Forwa def ones_like(x): if isinstance(x, TensorProxy): - return full_like(x, fill_value=1) + # NOTE: x could be a subclass of TensorProxy and that should be preserved. 
+ return type(x)(like=x) elif isinstance(x, NumberProxy): return type(x.value)(1) else: diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py index d371a4b9ab..87ff1cc2d0 100644 --- a/thunder/dynamo/utils.py +++ b/thunder/dynamo/utils.py @@ -6,6 +6,7 @@ import inspect import itertools import copy +from types import NoneType from collections import defaultdict from collections import namedtuple @@ -14,6 +15,11 @@ from torch._subclasses.fake_tensor import FakeTensor from torch.utils.weak import TensorWeakRef +if torch.distributed.is_available(): + from torch.distributed.tensor import DTensor +else: + DTensor = NoneType + from thunder.torch.default_torch_ops import torch_auto_registered_ops from thunder.torch import _torch_to_thunder_function_map from thunder.torch.langctx import torchctx @@ -508,7 +514,8 @@ def _get_storage_shape(t: torch.Tensor): def _get_min_and_val(t: torch.Tensor) -> tuple[Number | None, Number | None]: - if isinstance(t, FakeTensor) or t.device.type == "meta" or t.numel() == 0: + # We assume that for TensorSubclass, `aminmax` is not supported which is true for FakeTensor and DTensor. + if (isinstance(t, torch.Tensor) and type(t) is not torch.Tensor) or t.device.type == "meta" or t.numel() == 0: return None, None if t.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz): t = t.to(torch.float32) diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index c35658c5b1..63a939aa79 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -447,6 +447,11 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat # We only want the forward function to be called with `te.fp8_autocast` manager. bw_extrace._include_te_fp8_autocast = False + # We only want to apply it on backward trace. 
+ from thunder.torch.experimental.dtensor_utils import check_dtensor_cotangent_metadata_in_backward + + bw_extrace = check_dtensor_cotangent_metadata_in_backward(bw_extrace) + if len(bw_extrace.bound_symbols) == 1: # only return, no unpacking, so no gradient is calculated bw_extrace = None diff --git a/thunder/tests/distributed/test_dtensor.py b/thunder/tests/distributed/test_dtensor.py new file mode 100644 index 0000000000..4153ad6d20 --- /dev/null +++ b/thunder/tests/distributed/test_dtensor.py @@ -0,0 +1,120 @@ +import unittest + +import pytest +import torch + +if not torch.distributed.is_available(): + pytest.skip(allow_module_level=True) + +from thunder.dynamo import thunderfx +import thunder + +from thunder.tests.distributed.helper import DistributedParallelTestCase +from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor +from torch.distributed.tensor.placement_types import Placement, Shard, Replicate + + +@unittest.skipUnless( + torch.cuda.is_available() and torch.distributed.is_nccl_available(), + "DTensor test requires CUDA and NCCL `torch.distributed` backend", +) +class DTensorTest(DistributedParallelTestCase): + def test_dtensor_basic_op(self): + num_devices = self.world_size + mesh = DeviceMesh("cuda", list(range(num_devices))) + + dim_size = 16 + + def _helper(fn, in_dtensor, w_dtensor, compile_fn): + expected = torch.compile(fn)(in_dtensor, w_dtensor) + tmodel = compile_fn(fn) + actual = tmodel(in_dtensor, w_dtensor) + + torch.testing.assert_close(actual, expected) + + g_o = distribute_tensor(torch.ones(dim_size, dim_size), mesh, [Shard(0)]) + expected_g = torch.autograd.grad( + expected, + (in_dtensor, w_dtensor), + g_o, + ) + actual_g = torch.autograd.grad(actual, (in_dtensor, w_dtensor), g_o) + + torch.testing.assert_close(actual_g, expected_g) + if compile_fn is thunderfx: + assert len(tmodel._backend.subgraph_infos[0].split_reasons) > 0, "TODO: Fix with thunderfx path" + + w_dtensor = distribute_tensor(torch.randn(dim_size, 
dim_size, requires_grad=True), mesh, [Shard(0)]) + in_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)]) + + # Verify torch API works + _helper(lambda x, w: torch.mul(x, w), in_dtensor, w_dtensor, compile_fn=thunder.jit) + _helper(lambda x, w: torch.mul(x, w), in_dtensor, w_dtensor, compile_fn=thunderfx) + + # Verify calling method works + _helper(lambda x, w: torch.Tensor.mul(x, w), in_dtensor, w_dtensor, compile_fn=thunder.jit) + _helper(lambda x, w: torch.Tensor.mul(x, w), in_dtensor, w_dtensor, compile_fn=thunderfx) + + # # Verify calling special method works + _helper(lambda x, w: x * w, in_dtensor, w_dtensor, compile_fn=thunder.jit) + _helper(lambda x, w: x * w, in_dtensor, w_dtensor, compile_fn=thunderfx) + + def test_dtensor_unsupported(self): + num_devices = self.world_size + mesh = DeviceMesh("cuda", list(range(num_devices))) + + dim_size = 16 + + w_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)]) + + in_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)]) + + def fn(x, w): + return torch.div(x, w) + + tmodel = thunder.jit(fn) + with pytest.raises(AssertionError): + tmodel(in_dtensor, w_dtensor) + + def fn(x, w): + return x / w + + tmodel = thunder.jit(fn) + with pytest.raises(AssertionError): + tmodel(in_dtensor, w_dtensor) + + def test_dtensor_unsupported_mixed_input(self): + num_devices = self.world_size + mesh = DeviceMesh("cuda", list(range(num_devices))) + + dim_size = 16 + + def fn(x, w): + return torch.div(x, w) + + w = torch.randn(dim_size, dim_size, requires_grad=True) + + in_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)]) + + tmodel = thunder.jit(fn) + with pytest.raises(AssertionError): + tmodel(in_dtensor, w) + + def test_dtensor_incorrect_cotangent(self): + num_devices = self.world_size + mesh = DeviceMesh("cuda", list(range(num_devices))) + + 
dim_size = 16 + + w_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)]) + in_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)]) + + def fn(x, w): + return torch.mul(x, w) + + tmodel = thunder.jit(fn) + actual = tmodel(in_dtensor, w_dtensor) + g_o = distribute_tensor(torch.ones(dim_size, dim_size), mesh, [Shard(1)]) + + with pytest.raises(RuntimeError, match="has changed for cotangent between tracing and runtime"): + torch.autograd.grad(actual, (in_dtensor, w_dtensor), g_o) diff --git a/thunder/tests/test_transforms.py b/thunder/tests/test_transforms.py index d55ec0db47..d4a8588ed6 100644 --- a/thunder/tests/test_transforms.py +++ b/thunder/tests/test_transforms.py @@ -413,6 +413,7 @@ def forward(self, x): assert flat_arg_names == arg_names +@pytest.mark.xfail(strict=True) def test_constant_folding(): # Helper to verify we see the expected constant tensors # in exec_trace. diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 396c597d28..fe14ba0a68 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -139,6 +139,7 @@ def __init__( is_prim: bool = False, tags: None | list[Any] = None, out_of_place: Symbol | None = None, + allow_tensor_subclass_proxy: bool = False, ): self.torchfns = torchfns self.is_method = is_method or (method_name is not None) @@ -151,9 +152,30 @@ def __init__( self.tags = tags self.out_of_place = out_of_place + # This flag is used to enable/disable a torchsymbol to accept + # TensorProxy subclass as input (eg. DTensorProxy). + # By default, this is `False` as we don't want general `torchsymbol` + # which are meant for TensorProxy to accept DTensorProxy. 
+ self.allow_tensor_subclass_proxy = allow_tensor_subclass_proxy + def __call__(self, fn: Callable) -> Symbol: _fn = langctx(Languages.TORCH)(fn) + if not self.allow_tensor_subclass_proxy: + + @wraps(_fn) + def wrapper(*args, **kwargs): + filter_tensor_proxies = list( + filter(lambda t: isinstance(t, TensorProxy), tree_flatten((args, kwargs))[0]) + ) + assert all(map(lambda t: type(t) is TensorProxy, filter_tensor_proxies)), ( + f"Expected all inputs to be TensorProxy but found {list(map(lambda t: type(t), filter_tensor_proxies))}" + ) + return _fn(*args, **kwargs) + + else: + wrapper = _fn + id: str if self.id is None: name = fn.__name__ @@ -178,10 +200,10 @@ def __call__(self, fn: Callable) -> Symbol: if self.is_prim: sym = Symbol( - name=fn.__name__, meta=langctx(Languages.PRIMS)(_fn), id=id, is_prim=self.is_prim, tags=self.tags + name=fn.__name__, meta=langctx(Languages.PRIMS)(wrapper), id=id, is_prim=self.is_prim, tags=self.tags ) else: - sym = Symbol(name=fn.__name__, meta=_fn, id=id, is_prim=self.is_prim, tags=self.tags) + sym = Symbol(name=fn.__name__, meta=wrapper, id=id, is_prim=self.is_prim, tags=self.tags) if self.is_method: method_name: str = self.method_name if self.method_name is not None else fn.__name__ diff --git a/thunder/torch/experimental/dtensor_codeutils.py b/thunder/torch/experimental/dtensor_codeutils.py new file mode 100644 index 0000000000..2eef845df4 --- /dev/null +++ b/thunder/torch/experimental/dtensor_codeutils.py @@ -0,0 +1,35 @@ +from typing import Any +from torch.distributed.tensor._dtensor_spec import DTensorSpec, DeviceMesh, TensorMeta +from torch.distributed.tensor import DeviceMesh, Partial, Placement, Replicate, Shard + + +def populate_object_ctx_for_dtensor_spec(x: Any, object_ctx: dict[str, Any]) -> bool: + """ + Populate object context for DTensorSpec. + + ..note:: + This function will mutate the `object_ctx` + + Returns: + bool: True if `x` is DTensorSpec (and also updates `object_ctx`) otherwise False. 
+ """ + if isinstance(x, DTensorSpec): + object_ctx.update( + { + "DTensorSpec": DTensorSpec, + "DeviceMesh": DeviceMesh, + "Placement": Placement, + "Replicate": Replicate, + "Shard": Shard, + "Partial": Partial, + "TensorMeta": TensorMeta, + } + ) + return True + return False + + +def prettyprint_dtensor_spec(x): + if isinstance(x, DTensorSpec): + return x.__repr__() + return "" diff --git a/thunder/torch/experimental/dtensor_proxy.py b/thunder/torch/experimental/dtensor_proxy.py new file mode 100644 index 0000000000..03e5a7ead4 --- /dev/null +++ b/thunder/torch/experimental/dtensor_proxy.py @@ -0,0 +1,165 @@ +from thunder.core.proxies import TensorProxy, AnyProxy, _infer_tensor_properties +from torch.distributed._tensor import DTensor +from thunder.core.proxies import proxy +import thunder.core.devices as devices +import thunder.core.dtypes as dtypes +import thunder.core.utils as utils + + +# Inherit from TensorProxy as DTensor also supports +# Tensor methods like __add__, __div__, sin, etc. 
+class DTensorProxy(TensorProxy): + _DEFAULT_PREFIX = "dtensor_" + + def __init__( + self, + name=None, + *, + local_tensor=None, + spec=None, + like=None, + shape=None, + device=None, + dtype=None, + requires_grad=False, + grad=None, + prefix=None, + distparallel_type=None, + history=None, + tags=None, + thunder_fsdp_padding_size=None, + ): + super().__init__( + name, + like=like, + shape=shape, + device=device, + dtype=dtype, + requires_grad=requires_grad, + grad=grad, + prefix=prefix, + distparallel_type=distparallel_type, + history=history, + tags=tags, + thunder_fsdp_padding_size=thunder_fsdp_padding_size, + ) + if like is not None: + utils.check_type(like.spec if spec is None else spec, AnyProxy) + utils.check_type(like.local_tensor if local_tensor is None else local_tensor, TensorProxy) + self.spec = like.spec if spec is None else spec + self.local_tensor = like.local_tensor if local_tensor is None else local_tensor + else: + utils.check_type(spec, AnyProxy) + utils.check_type(local_tensor, TensorProxy) + self.spec = spec + self.local_tensor = local_tensor + + def type_string(self): + return f"DTensor {self.device.device_str()} {self.dtype.shortname()}{list(self._shape)} mesh={self.spec._o.mesh}, placements={self.spec._o.placements}" + + def replace(self, **changes): + r"""Return a copy of the TensorProxy object with new values for the specified fields as given to the constructor as arguments. + Valid keyword arguments are ``name``, ``history``, ``shape``, ``dtype``, ``device``, ``requires_grad``, ``distparallel_type``, ``thunder_fsdp_padding_size``. + ``like`` is also a valid keyword and will take metadata from the tensor proxy argument + in preference to the old values but overridable by keyword arguments. 
+ Note that the copy will use the current (environment) tracectx.""" + + like = changes.get("like") + ( + shape, + device, + dtype, + true_dtype, + numel, + ndim, + requires_grad, + grad, + distparallel_type, + thunder_fsdp_padding_size, + ) = _infer_tensor_properties( + like, + changes.get("shape", self._shape if like is None else None), + changes.get("device", self._device if like is None else None), + changes.get("dtype", self._dtype if like is None else None), + changes.get("requires_grad", self._requires_grad if like is None else None), + changes.get("grad", self._grad if like is None else None), + changes.get("distparallel_type", self._distparallel_type if like is None else None), + changes.get("thunder_fsdp_padding_size", self._thunder_fsdp_padding_size if like is None else None), + ) + name = changes.get("name", self.name) + history = changes.get("history", self.history) + tags = changes.get("tags", self.tags) + return DTensorProxy( + name=name, + local_tensor=self.local_tensor, + spec=self.spec, + tags=tags, + shape=shape, + device=device, + dtype=dtype, + requires_grad=requires_grad, + distparallel_type=distparallel_type, + thunder_fsdp_padding_size=thunder_fsdp_padding_size, + history=history, + ) + + +def proxify_dtensor(x, name: str | None = None, history: None | tuple = None) -> DTensorProxy | None: + if isinstance(x, DTensor): + spec_proxy = AnyProxy(x._spec, history=history) + t = x._local_tensor + shape = x.shape + device = devices.to_device(x.device) + dtype = dtypes.to_dtype(x.dtype) + grad = None + distparallel_type = None + _thunder_fsdp_padding_size = None + local_tensor_proxy = proxy(t, history=history) + return DTensorProxy( + name, + local_tensor=local_tensor_proxy, + spec=spec_proxy, + shape=tuple(shape), + device=device, + dtype=dtype, + requires_grad=x.requires_grad, + grad=grad, + distparallel_type=distparallel_type, + history=history, + thunder_fsdp_padding_size=_thunder_fsdp_padding_size, + ) + + return None + + +def 
create_dtensor_proxy_from_proxies(local_tensor: TensorProxy, spec: AnyProxy, requires_grad: bool) -> DTensorProxy: + """Creates a DTensorProxy from existing TensorProxy and AnyProxy objects. + + This function constructs a distributed tensor proxy by combining a local tensor proxy + with a specification proxy that contains distribution information. + + Args: + local_tensor (TensorProxy): The local tensor proxy representing the distributed tensor's local portion. + spec (AnyProxy): The specification proxy containing distribution information (mesh, placements, etc.). + requires_grad (bool): Whether the tensor requires gradient computation. + + Returns: + DTensorProxy: A new distributed tensor proxy combining the local tensor and distribution spec. + """ + utils.check_type(local_tensor, TensorProxy) + utils.check_type(spec, AnyProxy) + return DTensorProxy( + local_tensor=local_tensor, + spec=spec, + shape=tuple(spec._o.shape), + device=local_tensor.device, + dtype=local_tensor.dtype, + requires_grad=requires_grad, + grad=None, + distparallel_type=None, + thunder_fsdp_padding_size=None, + ) + + +def is_dtensor_proxy(x): + return isinstance(x, DTensorProxy) diff --git a/thunder/torch/experimental/dtensor_torch_and_prims.py b/thunder/torch/experimental/dtensor_torch_and_prims.py new file mode 100644 index 0000000000..194ffc5935 --- /dev/null +++ b/thunder/torch/experimental/dtensor_torch_and_prims.py @@ -0,0 +1,140 @@ +from functools import partial +from collections.abc import Callable + +from thunder.torch import torchsymbol, TensorLike, register_function +import thunder.torch as ltorch +from thunder.core.pytree import tree_flatten +from thunder import clang +from thunder.torch.experimental.dtensor_utils import run_with_fake_tensor +from thunder.torch.experimental.dtensor_proxy import DTensorProxy, create_dtensor_proxy_from_proxies +from thunder.torch.langctx import register_method +from thunder.core.prims import make_prim + +from thunder.core.proxies import TensorProxy, 
AnyProxy +from thunder.core.transforms import ( + register_grad, + put_grads, + get_grad, +) +from thunder.executors.torchex import ex as pytorchex +from thunder.executors.pythonex import ex as pythonex +from thunder.core.prims import make_prim, OpTags +from thunder.core import prims +from thunder.core import baseutils +from thunder.core import utils +from thunder import clang + +import torch + +dtensor_torchsymbol = partial(torchsymbol, allow_tensor_subclass_proxy=True) + + +def dispatch_to_impl(single_device_symbol, dtensor_symbol): + def wrapper(*args, **kwargs): + filter_tensor_proxies = list(filter(lambda t: isinstance(t, TensorProxy), tree_flatten((args, kwargs))[0])) + # number only variant of the operator. + if not filter_tensor_proxies: + return single_device_symbol(*args, **kwargs) + + dtensor_tensor_proxies = map(lambda t: isinstance(t, DTensorProxy), filter_tensor_proxies) + if all(dtensor_tensor_proxies): + return dtensor_symbol(*args, **kwargs) + else: + return single_device_symbol(*args, **kwargs) + + return wrapper + + +def register_function_for_dtensor(torch_fn, single_device_symbol, dtensor_symbol, is_method=False): + register_function(torch_fn, dispatch_to_impl(single_device_symbol, dtensor_symbol)) + + if is_method: + method_name: str = torch_fn.__name__ + torch_method: None | Callable = getattr(torch.Tensor, method_name, None) + register_method_for_dtensor(torch_method, single_device_symbol, dtensor_symbol) + + +def register_method_for_dtensor(torch_fn, single_device_symbol, dtensor_symbol): + method_wrapper = dispatch_to_impl(single_device_symbol, dtensor_symbol) + register_function(torch_fn, method_wrapper) + register_method(torch_fn.__name__, method_wrapper) + + +def _check_dtensor_spec_repr_meta(s: AnyProxy, value: str) -> None: + baseutils.check_type(s, AnyProxy) + + +check_dtensor_spec_repr = make_prim( + "check_dtensor_spec_repr", + "check_dtensor_spec_repr", + meta=_check_dtensor_spec_repr_meta, + tags=(OpTags.DONT_DCE,), +) + + +def 
_check_dtensor_spec_repr(s: object, value: str) -> None: + utils.check(repr(s) == value, lambda: f"Expected '{s}' to be equal to '{value}'") + + +_check_python_repr_impl = pythonex.register_operator( + "check_python_repr", like=check_dtensor_spec_repr, fn=_check_dtensor_spec_repr +) + +pythonex.register_implementation(check_dtensor_spec_repr, _check_python_repr_impl) + + +def handle_check_dtensor_spec_in_prologue(prim, prologue_trace, args) -> bool: + if prim == check_dtensor_spec_repr: + # How does torch.compile guard for this? + a = args[0] + o = AnyProxy(None, prefix="dtensor_spec") + bsym = prims.unpack_attr.bind(a, "_spec", output=o) + prologue_trace.bound_symbols.append(bsym) + check_dtensor_spec_repr(o, repr(args[1])) + + # Also adds metadata check for _local_tensor + t = TensorProxy(like=a.local_tensor, requires_grad=a.local_tensor.requires_grad) + bsym = prims.unpack_attr.bind(a, "_local_tensor", output=t) + prologue_trace.bound_symbols.append(bsym) + clang.check_tensor_shape_and_metadata(t) + return True + + return False + + +def dtensor_mul_meta(a, b): + output = run_with_fake_tensor(torch.mul, a, b) + local_tensor_proxy = TensorProxy(like=a.local_tensor) + spec = output._spec + spec_proxy = AnyProxy(spec, history=a.history) + return create_dtensor_proxy_from_proxies(local_tensor_proxy, spec_proxy, False) + + +dtensor_mul_prim = make_prim("dtensor_mul_prim", "dtensor_mul_prim", meta=dtensor_mul_meta) + +dtensor_mul_prim_impl = pytorchex.register_operator("dtensor_mul_prim", like=dtensor_mul_prim, fn=torch.mul) + +pytorchex.register_implementation(dtensor_mul_prim, dtensor_mul_prim_impl) + + +def _dtensor_mul_prim_grad(a: TensorLike, b: TensorLike) -> TensorLike: + fwd = dtensor_mul_prim(a, b) + + g = get_grad(fwd) + a_grad = dtensor_mul_prim(b, g) + b_grad = dtensor_mul_prim(a, g) + put_grads((a, b), (a_grad, b_grad)) + + return fwd + + +register_grad(dtensor_mul_prim, _dtensor_mul_prim_grad) + + +@dtensor_torchsymbol(torch.mul, id="dtensor.torch.mul") 
+def dtensor_mul(a: TensorLike, b: TensorLike) -> TensorLike: + return dtensor_mul_prim(a, b) + + +def register_dtensor_torch_and_prims(): + register_function_for_dtensor(torch.mul, ltorch.mul, dtensor_mul, is_method=True) diff --git a/thunder/torch/experimental/dtensor_utils.py b/thunder/torch/experimental/dtensor_utils.py new file mode 100644 index 0000000000..f66c4167ee --- /dev/null +++ b/thunder/torch/experimental/dtensor_utils.py @@ -0,0 +1,103 @@ +from collections.abc import Sequence +from typing import Any + +import torch +from torch.distributed.tensor import DTensor + +from thunder.core.pytree import tree_map +from thunder.core.proxies import TensorProxy, NumberProxy +from thunder.core.devices import to_torch_device +from thunder.core.dtypes import to_torch_dtype +from thunder.core.pytree import tree_map +from thunder.core.trace import TraceCtx +from thunder.torch.experimental.dtensor_proxy import DTensorProxy +from thunder.core.prims import PrimIDs +from thunder.core.symbol import Symbol +from thunder.core.trace import from_trace + +from torch._subclasses.fake_tensor import FakeTensorMode + + +def run_with_fake_tensor(torch_op, *args, **kwargs): + """ + Run a torch operation with fake tensors and return the output. + + Args: + torch_op: The torch operation to execute + *args: Arguments to pass to the torch operation + **kwargs: Keyword arguments to pass to the torch operation + + Returns: + The output of the torch operation executed with fake tensors + """ + + def f(*args, **kwargs): + return torch_op(*args, **kwargs) + + with FakeTensorMode(): + + def materialize_fake_tensors(t): + # Convert proxy types to fake tensors. 
+ if isinstance(t, NumberProxy): + return t.value + + if not isinstance(t, TensorProxy): + return t + + if isinstance(t, DTensorProxy): + i_t = torch.randn( + t.local_tensor.shape, + device=to_torch_device(t.local_tensor.device), + dtype=to_torch_dtype(t.local_tensor.dtype), + ) + return DTensor.from_local(i_t, t.spec._o.device_mesh, t.spec._o.placements) + + return torch.randn(t.shape, device=to_torch_device(t.device), dtype=to_torch_dtype(t.dtype)) + + args, kwargs = tree_map(materialize_fake_tensors, (args, kwargs)) + + return f(*args, **kwargs) + + +def check_dtensor_cotangent_metadata(dtensor, metadata): + if not dtensor._spec == metadata: + raise RuntimeError( + "Metadata (placement and mesh) has changed for cotangent between tracing and runtime. " + f"During tracing it was {metadata} but at runtime it is {dtensor._spec}." + ) + + +def check_dtensor_cotangent_metadata_in_backward(bw_trace: TraceCtx): + # NOTE: The metadata (placement and mesh) of the cotangent DTensor + # can be different at runtime than the one we assumed during tracing. + # Because of this, we currently add a check in backward to verify the same. + # However, in future, we should add a symbol which will take care of mapping + # the cotangent metadata at runtime to the cotangent metadata during tracing. + # Also refer: https://github.com/pytorch/pytorch/pull/118670 + + # Quick implementation of a symbol to verify + # that the metadata of the cotangent at runtime is the same as that during tracing. + check_dtensor_cotangent_metadata_symbol = Symbol( + name="check_dtensor_cotangent_metadata", + meta=lambda dtensor, metadata: None, + python_impl=check_dtensor_cotangent_metadata, + ) + new_bw_trace = from_trace(bw_trace) + new_bsyms = [] + for bsym in bw_trace.bound_symbols: + # Find the `unpack_sequence` for the cotangents. 
+ if bsym.sym.id == PrimIDs.UNPACK_SEQUENCE and bsym.args[0].name == "cotangents": + new_bsyms.append(bsym) + args = bsym.args[0].collection() + for arg in args: + # For every DTensor cotangent, + # add symbol to verify that the metadata is the same as during tracing. + if isinstance(arg, DTensorProxy): + bsym = check_dtensor_cotangent_metadata_symbol.bind(arg, arg.spec._o, output=None) + new_bsyms.append(bsym) + else: + new_bsyms.append(bsym) + + new_bw_trace.bound_symbols = new_bsyms + + return new_bw_trace diff --git a/thunder/transforms/autodiff.py b/thunder/transforms/autodiff.py index 80e6e855e0..deb331d1a4 100644 --- a/thunder/transforms/autodiff.py +++ b/thunder/transforms/autodiff.py @@ -541,6 +541,12 @@ def backward_fn(saved_for_backward, cotangents): prims.python_return(tuple(return_bsym.args[0]["output"])) backward_trace = dce(backward_trace) + + # We only want to apply it on backward trace. + from thunder.torch.experimental.dtensor_utils import check_dtensor_cotangent_metadata_in_backward + + backward_trace = check_dtensor_cotangent_metadata_in_backward(backward_trace) + return forward_trace, backward_trace