Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
have splitter see __eq__ as supported
Currently `backend.subgraph_infos[0].split_graph_module` in the snippet below
has `thunder_0` and `inductor_1`. The split is at `max_abs.__eq__(0)`, as
shown in the second `GraphModule` definition (bottom).
```python
import torch
from thunder.dynamo import thunderfx

class GraphModule(torch.nn.Module):
    # Minimal repro: reshape -> abs -> amax -> __eq__(0).
    # The tensor-method `__eq__` call at the end is the op the splitter
    # currently sends to inductor instead of thunder (the point of this PR).
    # NOTE(review): the string annotations ("bf16[...]") mirror dynamo's
    # captured shapes/dtypes — they are documentation, not enforced types.
    def forward(self, L_data_hp_: "bf16[4096, 4096][4096, 1]"):
        l_data_hp_ = L_data_hp_
        data_hp: "bf16[524288, 32][32, 1]" = l_data_hp_.reshape(-1, 32);  l_data_hp_ = None
        abs_1: "bf16[524288, 32][32, 1]" = torch.abs(data_hp)
        max_abs: "bf16[524288][1]" = torch.amax(abs_1, 1);  abs_1 = None
        # Method-call spelling (`__eq__`) rather than `torch.eq` / `==`:
        # this exact node form is what triggers the graph split.
        eq: "b8[524288][1]" = max_abs.__eq__(0)
        return eq

if __name__ == "__main__":
    # Run the repro on CUDA and print how thunderfx partitioned the graph.
    module = GraphModule().cuda()
    L_data_hp_ = torch.randn((4096, 4096), device="cuda", dtype=torch.bfloat16)
    jitted = thunderfx(module)
    out = jitted(L_data_hp_)
    # `_backend` is a private attribute of the thunderfx-compiled callable;
    # its subgraph_infos record the thunder/inductor split decisions.
    backend = jitted._backend
    backend.subgraph_infos[0].split_graph_module.print_readable()
```

```python
class GraphModule(torch.nn.Module):
    def forward(self, l_l_data_hp_: "bf16[4096, 4096]"):
        # No stacktrace found for following nodes
        thunder_0 = self.thunder_0(l_l_data_hp_);  l_l_data_hp_ = None
        inductor_1 = self.inductor_1(thunder_0);  thunder_0 = None
        return (inductor_1,)

    class thunder_0(torch.nn.Module):
        def forward(self, l_l_data_hp_: "bf16[4096, 4096]"):
            data_hp: "bf16[524288, 32]" = l_l_data_hp_.reshape(-1, 32);  l_l_data_hp_ = None

            abs_1: "bf16[524288, 32]" = torch.abs(data_hp);  data_hp = None

            max_abs: "bf16[524288]" = torch.amax(abs_1, 1);  abs_1 = None
            return max_abs

        class _model(torch.nn.Module):
            def forward(self, l_l_data_hp_: "bf16[4096, 4096]"):
                data_hp: "bf16[524288, 32]" = l_l_data_hp_.reshape(-1, 32);  l_l_data_hp_ = None

                abs_1: "bf16[524288, 32]" = torch.abs(data_hp);  data_hp = None

                max_abs: "bf16[524288]" = torch.amax(abs_1, 1);  abs_1 = None
                return max_abs

    class inductor_1(torch.nn.Module):
        def forward(self, max_abs: "bf16[524288]"):
            eq: "b8[524288]" = max_abs.__eq__(0);  max_abs = None
            return eq

        class _orig_mod(torch.nn.Module):
            def forward(self, max_abs: "bf16[524288]"):
                eq: "b8[524288]" = max_abs.__eq__(0);  max_abs = None
                return eq
```

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
  • Loading branch information
crcrpar committed Jun 10, 2025
commit d62deb147d609dd375f0367bf1c286d21710dbae
2 changes: 1 addition & 1 deletion thunder/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2245,7 +2245,7 @@ def div_(
return _copy_(a, div(a, b))


@torchsymbol(torch.eq, is_method=True)
@torchsymbol(torch.eq, torch.Tensor.__eq__, is_method=True)
def eq(a, b, /):
return clang.eq(a, b)

Expand Down
Loading