Representing DTensor in thunder traces #1907
Co-authored-by: Masaki Kozuki <mkozuki@nvidia.com>
… dtensor-init-support
This reverts commit 225f2e3.
Gentle ping @IvanYashchuk
…5/lightning-thunder into dtensor-init-support
for more information, see https://pre-commit.ci
With the latest merge, I am seeing a failure in the test for Cause - But this PR updates this map to be Workaround - cc: @IvanYashchuk
Ping @t-vi for stamp
t-vi left a comment
Thank you @kshitij12345 @IvanYashchuk @crcrpar
@kshitij12345 will you file an issue about USE_DISTRIBUTED=OFF?
Opened #2233 for tracking that thunder will work with PyTorch compiled without distributed.
Fixes #1898
Design Doc - https://docs.google.com/document/d/1Gqb_jXrL-sSqs-D8KrZdcQinxuUSlccZBnnvbYJfYl0/edit?usp=sharing
Changes -
This PR adds support for DTensor inputs to the jitted function. Most of the additions required to support DTensor, such as the `DTensorProxy` and related prims, are in `thunder/torch/experimental`.
NOTE: … `torch.mul` and no broadcast). Broader coverage will follow in subsequent PRs.
Following are the main updates:
Prologue: Adds a new primitive `check_dtensor_spec_repr`, which matches the repr of the `DTensorSpec` of the DTensor in question (see the example below). The PR also makes sure that, besides the `DTensorSpec`, there is a tensor metadata check for the `DTensor` object as well as for the local tensor that it points to. NOTE - Another option for checking the `DTensorSpec` would be to keep the inputs' `DTensorSpec` in the TracingContext and have the prologue verify equality.
DTensorProxy: Adds a new Proxy object to represent the `DTensor`. This class inherits from `TensorProxy`, as `DTensor` is a tensor subclass and implements all the same methods that a tensor implements.
Prims and Operations: For the computation trace, we add prims and torch-level operations for DTensor. We add new prims and operations instead of reusing the existing ones to prevent the executors from claiming an operation on DTensor by mistake.
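The repr-based prologue check can be illustrated with a small, self-contained sketch. It is not thunder's implementation: `ToySpec` is a hypothetical stand-in for `DTensorSpec`, and `make_spec_guard` is a made-up helper name. It only shows the idea of recording the spec's repr at trace time and comparing it against the repr of the input seen at call time.

```python
from dataclasses import dataclass


# Hypothetical stand-in for DTensorSpec (assumption for illustration only);
# the real class describes a device mesh, placements, and tensor metadata.
@dataclass(frozen=True)
class ToySpec:
    mesh_shape: tuple
    placements: tuple


def make_spec_guard(traced_spec):
    """Build a guard that compares repr strings, in the spirit of a prologue check."""
    expected = repr(traced_spec)  # captured once, when the function is traced

    def guard(runtime_spec):
        actual = repr(runtime_spec)
        if actual != expected:
            raise RuntimeError(f"spec mismatch: expected {expected}, got {actual}")

    return guard


guard = make_spec_guard(ToySpec((2,), ("Shard(0)",)))

# Same layout as at trace time: the guard passes silently.
guard(ToySpec((2,), ("Shard(0)",)))

# Different placement: the guard rejects the input.
try:
    guard(ToySpec((2,), ("Replicate()",)))
    mismatch_detected = False
except RuntimeError:
    mismatch_detected = True
```

Comparing reprs keeps the guard a plain string comparison; the alternative mentioned in the note above would instead store the spec object and compare it for equality.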
Representation in trace -
Example Program
Prologue Trace (relevant snippet)
Computation Trace: There is a `torch`-level symbol `dtensor_mul`, which is decomposed into prims for DTensor operations.
Backward Trace (initial trace)
Thank you Masaki, Ivan and Mike for the helpful discussions and guidance!