[thunderfx] Avoid split at Tensor.__eq__ by registering it in thunder.torch#2211
Merged
[thunderfx] Avoid split at Tensor.__eq__ by registering it in thunder.torch#2211
Tensor.__eq__ by registering it in thunder.torch#2211Conversation
Currently `backend.subgraph_infors[0].split_graph_module` in the snippet bwelow
has `thunder_0` and `inductor_1`. The split is at `max_abs.__eq__(0)` as
in the second `GraphModule` definition (bottom).
```python
import torch
from thunder.dynamo import thunderfx
class GraphModule(torch.nn.Module):
def forward(self, L_data_hp_: "bf16[4096, 4096][4096, 1]"):
l_data_hp_ = L_data_hp_
data_hp: "bf16[524288, 32][32, 1]" = l_data_hp_.reshape(-1, 32); l_data_hp_ = None
abs_1: "bf16[524288, 32][32, 1]" = torch.abs(data_hp)
max_abs: "bf16[524288][1]" = torch.amax(abs_1, 1); abs_1 = None
eq: "b8[524288][1]" = max_abs.__eq__(0)
return eq
if __name__ == "__main__":
module = GraphModule().cuda()
L_data_hp_ = torch.randn((4096, 4096), device="cuda", dtype=torch.bfloat16)
jitted = thunderfx(module)
out = jitted(L_data_hp_)
backend = jitted._backend
backend.subgraph_infos[0].split_graph_module.print_readable()
```
```python
class GraphModule(torch.nn.Module):
def forward(self, l_l_data_hp_: "bf16[4096, 4096]"):
# No stacktrace found for following nodes
thunder_0 = self.thunder_0(l_l_data_hp_); l_l_data_hp_ = None
inductor_1 = self.inductor_1(thunder_0); thunder_0 = None
return (inductor_1,)
class thunder_0(torch.nn.Module):
def forward(self, l_l_data_hp_: "bf16[4096, 4096]"):
data_hp: "bf16[524288, 32]" = l_l_data_hp_.reshape(-1, 32); l_l_data_hp_ = None
abs_1: "bf16[524288, 32]" = torch.abs(data_hp); data_hp = None
max_abs: "bf16[524288]" = torch.amax(abs_1, 1); abs_1 = None
return max_abs
class _model(torch.nn.Module):
def forward(self, l_l_data_hp_: "bf16[4096, 4096]"):
data_hp: "bf16[524288, 32]" = l_l_data_hp_.reshape(-1, 32); l_l_data_hp_ = None
abs_1: "bf16[524288, 32]" = torch.abs(data_hp); data_hp = None
max_abs: "bf16[524288]" = torch.amax(abs_1, 1); abs_1 = None
return max_abs
class inductor_1(torch.nn.Module):
def forward(self, max_abs: "bf16[524288]"):
eq: "b8[524288]" = max_abs.__eq__(0); max_abs = None
return eq
class _orig_mod(torch.nn.Module):
def forward(self, max_abs: "bf16[524288]"):
eq: "b8[524288]" = max_abs.__eq__(0); max_abs = None
return eq
```
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.

What does this PR do?
Fixes the unexpected split at
Tensor.__eq__.Currently
backend.subgraph_infors[0].split_graph_modulein the snippet bwelow hasthunder_0andinductor_1. The split is atmax_abs.__eq__(0)as in the secondGraphModuledefinition (bottom).