@@ -76,15 +76,18 @@ def parallelize_deepseekv3(
with disable_compile(job_config):
model = simple_fsdp_parallelize_deepseekv3(model, parallel_dims, job_config)

# Get backend from config
backend = job_config.compile.backend

# Get joint custom passes from config
joint_custom_passes = get_joint_custom_passes_from_config(parallel_dims, job_config)

# Get compiler passes from config
compiler_passes = get_compiler_passes_from_config(model, job_config)

# Create compilers with specified passes (defaults to no passes)
# Create compilers with specified passes and backend
fw_compiler, bw_compiler = make_compiler_with_passes(
compiler_passes, dump_folder=job_config.job.dump_folder
compiler_passes, dump_folder=job_config.job.dump_folder, backend=backend
)

# Create custom joint_graph_builder with deepseekv3-specific compilers
@@ -94,6 +97,7 @@ def parallelize_deepseekv3(
bw_compiler=bw_compiler,
joint_custom_passes=joint_custom_passes,
dump_folder=job_config.job.dump_folder,
backend=backend,
)

# TODO: CompiledModule should take sample input as well, so that we can
122 changes: 94 additions & 28 deletions torchtitan/experiments/compiler_toolkit/graph_utils.py
@@ -17,6 +17,7 @@
JointWithDescriptors,
)
from torch._guards import tracing, TracingContext
from torch._inductor.decomposition import select_decomp_table
from torch.distributed.tensor import DTensor
from torchtitan.config import JobConfig
from torchtitan.distributed import ParallelDims
@@ -37,8 +38,18 @@ def _dump_gm(dump_folder: str | None, gm: torch.fx.GraphModule, name: str) -> None


def export_joint(
model, args, kwargs=None, dump_folder: str | None = None
model, args, kwargs=None, dump_folder: str | None = None, decompositions=None
) -> tuple[JointWithDescriptors, TracingContext]:
"""
Export joint forward-backward graph with AOT Autograd.

Args:
model: The model to export
args: Tuple of input arguments
kwargs: Dict of keyword arguments for the model
dump_folder: Optional folder to dump the graph to
decompositions: Optional decomposition table for AOT Autograd
"""
if kwargs is None:
kwargs = {}
assert isinstance(args, tuple)
@@ -62,12 +73,25 @@

with tracing(tracing_context):
return (
aot_export_joint_with_descriptors_alone(gm, args, kwargs),
aot_export_joint_with_descriptors_alone(
gm, args, kwargs, decompositions=decompositions
),
tracing_context,
)


def aot_export_joint_with_descriptors_alone(model, args, kwargs=None):
def aot_export_joint_with_descriptors_alone(
model, args, kwargs=None, decompositions=None
):
"""
Export joint forward-backward graph with AOT Autograd.

Args:
model: The model to export
args: Tuple of input arguments
kwargs: Dict of keyword arguments for the model
decompositions: Optional decomposition table for AOT Autograd.
"""
if kwargs is None:
kwargs = {}
assert isinstance(args, tuple)
@@ -78,6 +102,7 @@ def aot_export_joint_with_descriptors_alone(model, args, kwargs=None):
model,
args,
kwargs,
decompositions=decompositions,
Review comment (Contributor): we should be careful about exposing decompositions, since it would change numerics, and is a potential footgun.

)
return joint_with_descriptors

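For context on the footgun comment above: a decomposition table is a dict mapping ATen op overloads to Python reference implementations, and tracing with one rewrites each matched op into simpler ops -- which is exactly how numerics can shift. A minimal sketch, assuming a recent PyTorch build:

    from torch._inductor.decomposition import select_decomp_table

    decomps = select_decomp_table()
    # Dict[OpOverload, Callable]: each entry replaces a high-level op with a
    # reference implementation during tracing.
    print(len(decomps))  # several hundred entries on recent builds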
@@ -90,6 +115,7 @@ def joint_graph_builder(
bw_compiler: Optional[Callable] = None,
joint_custom_passes: Optional[List[Callable]] = None,
dump_folder: str | None = None,
backend: str = "aot_eager",
):
"""
Build a joint forward-backward graph for the model with optional custom compilers.
@@ -102,16 +128,23 @@
bw_compiler: Optional custom backward compiler function
joint_custom_passes: list of custom passes to run on the joint graph
dump_folder: Optional folder to dump the graph to
backend: Compilation backend ("aot_eager", "inductor")
"""
assert isinstance(model_args, tuple)
for idx, arg in enumerate(model_args):
assert isinstance(arg, DTensor), f"Argument {idx} is of type {type(arg)}"

# Use Inductor's decomposition table when backend is "inductor"
decompositions = select_decomp_table() if backend == "inductor" else None
Review comment (Contributor): This is introducing a coupling between frontend and backend. Decomposition should be a per-backend concept, so ideally it should be an internal step of inductor.compile.


# get joint graph
(
joint_with_descriptors,
tracing_context,
) = export_joint(model, model_args, model_kwargs, dump_folder=dump_folder)
(joint_with_descriptors, tracing_context,) = export_joint(
model,
model_args,
model_kwargs,
dump_folder=dump_folder,
decompositions=decompositions,
)

# run custom passes on joint-graph before partitioner
if joint_custom_passes is not None:
@@ -270,37 +303,70 @@ def compiler(


def make_compiler_with_passes(
passes: List[Callable] = None, dump_folder: str | None = None
passes: List[Callable] = None,
dump_folder: str | None = None,
backend: str = "aot_eager",
Review comment (Contributor): I am not sure we need to add this new backend argument in the AOT workflow, since we can introduce full inductor as a graph pass (just like regional inductor). That way we don't need to introduce another "backend knob", and the AOT flow will always be: eager model -> graph capture -> full graph in ATen IR -> apply graph passes -> optimized graph. WDYT? @SherlockNoMad

Reply (Author): Full inductor compilation using compile_fx_inner is not an FX graph pass: it takes an FX graph and returns a compiled artifact. But it can be exposed to the CLI as a "pass" and made to work similarly to the current PR.
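To make the Author's point concrete, a hypothetical sketch of exposing full Inductor as the terminal "pass" in the pass list; inductor_compile_pass is an assumed name, not an existing torchtitan helper. Unlike a true FX pass it returns a compiled callable rather than a GraphModule, so it could only ever run last:

    import torch
    from torch._inductor.compile_fx import compile_fx_inner

    def inductor_compile_pass(gm: torch.fx.GraphModule, example_inputs):
        # Consumes the already-transformed ATen-IR graph and hands it to
        # Inductor codegen; the result is no longer an FX graph.
        return compile_fx_inner(gm, example_inputs)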
):
"""
Create forward and backward compilers with specified passes.
Create forward and backward compilers with specified passes and backend.

Args:
passes: List of compiler pass functions to apply. If None, uses DEFAULT_COMPILER_PASSES.
dump_folder: Optional folder to dump graphs
backend: Compilation backend ("aot_eager", "inductor")

Returns:
Tuple of (fw_compiler, bw_compiler) functions
"""
from torch._inductor.compile_fx import compile_fx_inner

if backend == "inductor":
# Use compile_fx_inner as the final compiler after applying transformation passes
def fw_compiler(gm: torch.fx.GraphModule, example_inputs):
gm = compiler(
"fwd_gm",
gm,
example_inputs,
passes=passes,
dump_folder=dump_folder,
is_forward=True,
)
logger.info("Compiling forward graph with Inductor (compile_fx_inner)")
return compile_fx_inner(gm, example_inputs)
Review comment (Contributor): For example, decomposition would be better applied before or inside compile_fx_inner.

Reply (Author): @SherlockNoMad -- what is the recommended way to apply decompositions to the graph module here? I tried make_fx, but it requires tracing the graph again, which fails because the graph module contains ops like graphsafe_run_with_rng_state that it doesn't handle. compile_fx_inner seems to expect the graph to have been decomposed by this point -- the only way I could get it to work was as part of the AOT full-graph capture.

Reply (Contributor): Hmm @SherlockNoMad, here's the way I'd picture doing this in the compiler-toolkit world, but let me know what you think:

(1) Make decomp_table an argument to joint_graph_builder(). So if someone wants to compile with inductor, they need to perform graph capture with their decomp table set to torch._inductor.decomposition.select_decomp_table().

The other option would be to have the user pass in a string backend name like "inductor", where we automatically pick the right decomp table for them. I guess it depends on how automagic we want to make things -- and in the toolkit flow, making the user spell out things like their decomposition_table manually fits a bit better with the goals of the compiler toolkit (compared to torch.compile, where there is only one API and we choose all of the right defaults automatically, which "just works" but is harder to reason about). If we are worried about users doing the wrong thing, we could raise an error if the user specified compile_fx_inner as their compiler but forgot to pass in the inductor decomp table.

(2) We'd need to decide where the decomps should run. I think we have two options:

(2a) as an extra graph pass
(2b) let the decompositions run while we trace out the joint graph

Any preference between the two? They should produce the same result. 2a is more in line with what torch.compile does today. It also has the advantage of being faster (because 2b requires doing an unnecessary second trace, which is roughly half the cost of running AOTAutograd).

One argument I could see for doing "decomps as a graph pass" is if the user wants to write a pattern-match-based graph pass on the graph before decomps run. Say they want to use inductor, but they also want to pattern-match aten.rms_norm with their custom kernel.

Reply (Contributor):

1. In the spirit of "frontier users wishing to control and customize everything", I think an explicit decomp_table is better than an automagic "inductor" string backend. The user can always explicitly spell out torch._inductor.decomposition.select_decomp_table() to choose inductor decomps.
2. Where to apply: I think 2a is better, as it's more flexible and explicit.

Reply (Contributor): I think the design spirit of the "compiler toolkit" is "highly customizable, with modular off-the-shelf components". To spell it out:

• the graph pass pipeline is fully customizable: the user can craft any graph transformations in their own sequence;
• off-the-shelf components would be: a "decomposition pass", an "inductor compile pass", ...

There could be coupling between passes, e.g. the "inductor pass" would require the "inductor decomp pass" as a prerequisite. But I do want to minimize frontend-backend coupling, e.g. use of the "inductor backend" shouldn't require the export frontend to use a particular config.

Reply (Author): I have tried this in a few different ways but ran into stumbling blocks. Let's discuss offline.

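For reference, a minimal sketch of option (1) above -- passing the decomposition table explicitly at capture time instead of a backend string. The decompositions kwarg mirrors the one this PR adds to export_joint; the rest of the call shape is an assumption:

    from torch._inductor.decomposition import select_decomp_table

    joint_with_descriptors, tracing_context = export_joint(
        model,
        model_args,
        model_kwargs,
        dump_folder=dump_folder,
        # Explicit, user-spelled choice of Inductor decomps; no "inductor"
        # string knob needed on the frontend.
        decompositions=select_decomp_table(),
    )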

def bw_compiler(gm: torch.fx.GraphModule, example_inputs):
gm = compiler(
"bwd_gm",
gm,
example_inputs,
passes=passes,
dump_folder=dump_folder,
is_forward=False,
)
logger.info("Compiling backward graph with Inductor (compile_fx_inner)")
return compile_fx_inner(gm, example_inputs)

else:
Review comment (SherlockNoMad, Contributor, Dec 15, 2025): I think you can rewrite this to something like the following to avoid code duplication:

    # apply common passes here
    gm = common_passes(...)

    # apply backend-specific passes
    if backend == "inductor":
        gm = compile_fx_inner(gm, ...)

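A fuller, runnable sketch of that suggestion, assuming the compiler helper and imports already present in this file; this factory shape is an illustration, not what the PR merged:

    import torch
    from torch._inductor.compile_fx import compile_fx_inner

    def make_compiler_with_passes(passes=None, dump_folder=None, backend="aot_eager"):
        def _make(name: str, is_forward: bool):
            def _compile(gm: torch.fx.GraphModule, example_inputs):
                # Common toolkit passes run regardless of backend.
                gm = compiler(
                    name, gm, example_inputs,
                    passes=passes, dump_folder=dump_folder, is_forward=is_forward,
                )
                # Backend-specific terminal step: hand the transformed graph
                # to Inductor codegen; otherwise return the GraphModule as-is.
                if backend == "inductor":
                    return compile_fx_inner(gm, example_inputs)
                return gm
            return _compile
        return _make("fwd_gm", is_forward=True), _make("bwd_gm", is_forward=False)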

def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> torch.fx.GraphModule:
return compiler(
"fwd_gm",
gm,
example_inputs,
passes=passes,
dump_folder=dump_folder,
is_forward=True,
)

def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> torch.fx.GraphModule:
return compiler(
"bwd_gm",
gm,
example_inputs,
passes=passes,
dump_folder=dump_folder,
is_forward=False,
)

return fw_compiler, bw_compiler

14 changes: 11 additions & 3 deletions torchtitan/experiments/compiler_toolkit/job_config.py
@@ -5,16 +5,24 @@
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from typing import Literal


@dataclass
class Compile:
"""
List of compiler pass names to apply in the compiler toolkit workflow.
By default, no passes are applied.
Example: --compile.passes autobucketing_reordering,regional_inductor
Compiler configuration for the compiler toolkit workflow.
- backend: The compilation backend to use. Options are:
- "aot_eager": AOT Autograd with eager backend (graph transformations only)
- "inductor": Full Inductor compilation with optimized code generation
- passes: List of compiler pass names to apply in the compiler toolkit workflow.
Example: --compile.passes autobucketing_reordering
"""

backend: Literal["aot_eager", "inductor"] = "aot_eager"
Review comment (Contributor): Also, we should warn about the numerics-changing behavior when using inductor.

Review comment (Contributor): The reason we have this class Compile is to add custom fields that don't exist in torchtitan/config/job_config.py. backend is an existing field, so no need to add it here.

Reply (aditvenk, Author, Dec 16, 2025): Ah gotcha! Can remove it.

passes: list[str] = field(default_factory=list)

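For illustration, a minimal sketch of combining the two fields -- direct construction is shown for clarity, whereas in torchtitan these would normally be populated from the CLI, e.g. --compile.backend inductor --compile.passes autobucketing_reordering:

    # Hypothetical direct construction, equivalent to the CLI flags above.
    cfg = Compile(backend="inductor", passes=["autobucketing_reordering"])
    assert cfg.backend in ("aot_eager", "inductor")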

8 changes: 6 additions & 2 deletions torchtitan/experiments/compiler_toolkit/llama3/parallelize.py
@@ -63,15 +63,18 @@ def parallelize_llama(
with disable_compile(job_config):
model = simple_fsdp_parallelize_llama(model, parallel_dims, job_config)

# Get backend from config
backend = job_config.compile.backend

# Get joint custom passes from config
joint_custom_passes = get_joint_custom_passes_from_config(parallel_dims, job_config)

# Get compiler passes from config
compiler_passes = get_compiler_passes_from_config(model, job_config)

# Create compilers with specified passes (defaults to no passes)
# Create compilers with specified passes and backend
fw_compiler, bw_compiler = make_compiler_with_passes(
compiler_passes, dump_folder=job_config.job.dump_folder
compiler_passes, dump_folder=job_config.job.dump_folder, backend=backend
)

# Create custom joint_graph_builder with llama-specific compilers
Expand All @@ -81,6 +84,7 @@ def parallelize_llama(
bw_compiler=bw_compiler,
joint_custom_passes=joint_custom_passes,
dump_folder=job_config.job.dump_folder,
backend=backend,
)

# TODO: CompiledModule should take sample input as well, so that we can
Expand Down
Loading