Skip to content

[thunderfx] Avoid split at Tensor.__eq__ by registering it in thunder.torch#2211

Merged
mruberry merged 1 commit into main from
no-split-at-builtin-eq
Jun 10, 2025
Merged

[thunderfx] Avoid split at Tensor.__eq__ by registering it in thunder.torch#2211
mruberry merged 1 commit into main from
no-split-at-builtin-eq

Conversation

@crcrpar
Copy link
Copy Markdown
Collaborator

@crcrpar crcrpar commented Jun 10, 2025

What does this PR do?

Fixes the unexpected split at Tensor.__eq__.

Currently backend.subgraph_infos[0].split_graph_module in the snippet below has thunder_0 and inductor_1. The split is at max_abs.__eq__(0) as in the second GraphModule definition (bottom).

import torch
from thunder.dynamo import thunderfx

class GraphModule(torch.nn.Module):
    def forward(self, L_data_hp_: "bf16[4096, 4096][4096, 1]"):
        l_data_hp_ = L_data_hp_
        data_hp: "bf16[524288, 32][32, 1]" = l_data_hp_.reshape(-1, 32);  l_data_hp_ = None
        abs_1: "bf16[524288, 32][32, 1]" = torch.abs(data_hp)
        max_abs: "bf16[524288][1]" = torch.amax(abs_1, 1);  abs_1 = None
        eq: "b8[524288][1]" = max_abs.__eq__(0)
        return eq

if __name__ == "__main__":
    module = GraphModule().cuda()
    L_data_hp_ = torch.randn((4096, 4096), device="cuda", dtype=torch.bfloat16)
    jitted = thunderfx(module)
    out = jitted(L_data_hp_)
    backend = jitted._backend
    backend.subgraph_infos[0].split_graph_module.print_readable()
class GraphModule(torch.nn.Module):
    def forward(self, l_l_data_hp_: "bf16[4096, 4096]"):
        # No stacktrace found for following nodes
        thunder_0 = self.thunder_0(l_l_data_hp_);  l_l_data_hp_ = None
        inductor_1 = self.inductor_1(thunder_0);  thunder_0 = None
        return (inductor_1,)

    class thunder_0(torch.nn.Module):
        def forward(self, l_l_data_hp_: "bf16[4096, 4096]"):
            data_hp: "bf16[524288, 32]" = l_l_data_hp_.reshape(-1, 32);  l_l_data_hp_ = None

            abs_1: "bf16[524288, 32]" = torch.abs(data_hp);  data_hp = None

            max_abs: "bf16[524288]" = torch.amax(abs_1, 1);  abs_1 = None
            return max_abs

        class _model(torch.nn.Module):
            def forward(self, l_l_data_hp_: "bf16[4096, 4096]"):
                data_hp: "bf16[524288, 32]" = l_l_data_hp_.reshape(-1, 32);  l_l_data_hp_ = None

                abs_1: "bf16[524288, 32]" = torch.abs(data_hp);  data_hp = None

                max_abs: "bf16[524288]" = torch.amax(abs_1, 1);  abs_1 = None
                return max_abs

    class inductor_1(torch.nn.Module):
        def forward(self, max_abs: "bf16[524288]"):
            eq: "b8[524288]" = max_abs.__eq__(0);  max_abs = None
            return eq

        class _orig_mod(torch.nn.Module):
            def forward(self, max_abs: "bf16[524288]"):
                eq: "b8[524288]" = max_abs.__eq__(0);  max_abs = None
                return eq

Currently `backend.subgraph_infos[0].split_graph_module` in the snippet below
has `thunder_0` and `inductor_1`. The split is at `max_abs.__eq__(0)` as
in the second `GraphModule` definition (bottom).
```python
import torch
from thunder.dynamo import thunderfx

class GraphModule(torch.nn.Module):
    def forward(self, L_data_hp_: "bf16[4096, 4096][4096, 1]"):
        l_data_hp_ = L_data_hp_
        data_hp: "bf16[524288, 32][32, 1]" = l_data_hp_.reshape(-1, 32);  l_data_hp_ = None
        abs_1: "bf16[524288, 32][32, 1]" = torch.abs(data_hp)
        max_abs: "bf16[524288][1]" = torch.amax(abs_1, 1);  abs_1 = None
        eq: "b8[524288][1]" = max_abs.__eq__(0)
        return eq

if __name__ == "__main__":
    module = GraphModule().cuda()
    L_data_hp_ = torch.randn((4096, 4096), device="cuda", dtype=torch.bfloat16)
    jitted = thunderfx(module)
    out = jitted(L_data_hp_)
    backend = jitted._backend
    backend.subgraph_infos[0].split_graph_module.print_readable()
```

```python
class GraphModule(torch.nn.Module):
    def forward(self, l_l_data_hp_: "bf16[4096, 4096]"):
        # No stacktrace found for following nodes
        thunder_0 = self.thunder_0(l_l_data_hp_);  l_l_data_hp_ = None
        inductor_1 = self.inductor_1(thunder_0);  thunder_0 = None
        return (inductor_1,)

    class thunder_0(torch.nn.Module):
        def forward(self, l_l_data_hp_: "bf16[4096, 4096]"):
            data_hp: "bf16[524288, 32]" = l_l_data_hp_.reshape(-1, 32);  l_l_data_hp_ = None

            abs_1: "bf16[524288, 32]" = torch.abs(data_hp);  data_hp = None

            max_abs: "bf16[524288]" = torch.amax(abs_1, 1);  abs_1 = None
            return max_abs

        class _model(torch.nn.Module):
            def forward(self, l_l_data_hp_: "bf16[4096, 4096]"):
                data_hp: "bf16[524288, 32]" = l_l_data_hp_.reshape(-1, 32);  l_l_data_hp_ = None

                abs_1: "bf16[524288, 32]" = torch.abs(data_hp);  data_hp = None

                max_abs: "bf16[524288]" = torch.amax(abs_1, 1);  abs_1 = None
                return max_abs

    class inductor_1(torch.nn.Module):
        def forward(self, max_abs: "bf16[524288]"):
            eq: "b8[524288]" = max_abs.__eq__(0);  max_abs = None
            return eq

        class _orig_mod(torch.nn.Module):
            def forward(self, max_abs: "bf16[524288]"):
                eq: "b8[524288]" = max_abs.__eq__(0);  max_abs = None
                return eq
```

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
@crcrpar crcrpar added the thunderfx for things that could be applicable to the dynamo+thunder frontend label Jun 10, 2025
Copy link
Copy Markdown
Collaborator

@mruberry mruberry left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mruberry mruberry enabled auto-merge (squash) June 10, 2025 15:36
@mruberry mruberry merged commit 62d9535 into main Jun 10, 2025
50 checks passed
@mruberry mruberry deleted the no-split-at-builtin-eq branch June 10, 2025 15:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

thunderfx for things that could be applicable to the dynamo+thunder frontend

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants