
Conversation

@aditvenk


Being able to compile fw/bw graphs using compile_fx_inner will help with establishing perf rooflines.

Manual testing:

    NGPU=4 \
    CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml \
    TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train \
    ./run_train.sh \
    --model.name compiler_toolkit.llama3 \
    --parallelism.data_parallel_shard_degree=2 \
    --parallelism.tensor_parallel_degree=2 \
    --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config \
    --compile.backend inductor

    assert isinstance(arg, DTensor), f"Argument {idx} is of type {type(arg)}"

    # Use Inductor's decomposition table when backend is "inductor"
    decompositions = select_decomp_table() if backend == "inductor" else None
Contributor:

This is introducing a coupling between frontend and backend.

Decomposition should be a per-backend concept, so ideally it should be an internal step of inductor.compile.
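
A minimal sketch of that decoupling, assuming a hypothetical apply_decompositions helper (the frontend stays backend-agnostic and the backend owns its decomposition table):

    # Hypothetical: each backend applies its own decompositions internally,
    # so the frontend never branches on the backend name.
    from torch._inductor.compile_fx import compile_fx_inner
    from torch._inductor.decomposition import select_decomp_table

    def inductor_compile(gm, example_inputs):
        # apply_decompositions is a placeholder for whatever mechanism
        # lowers the graph through the decomp table before compilation
        gm = apply_decompositions(gm, select_decomp_table())
        return compile_fx_inner(gm, example_inputs)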

    is_forward=True,
    )
    logger.info("Compiling forward graph with Inductor (compile_fx_inner)")
    return compile_fx_inner(gm, example_inputs)
Contributor:

For example, decompositions would be better applied before or inside compile_fx_inner.

Author:

@SherlockNoMad -- what is the recommended way to apply decompositions to the graph module here? I tried make_fx, but it requires tracing the graph again, which fails because the graph module contains ops like graphsafe_run_with_rng_state that it doesn't handle. compile_fx_inner seems to expect the graph to already be decomposed at this point, and the only way I could get it to work was as part of the AOT full-graph capture.
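
For reference, the retracing approach described above would look roughly like the sketch below (make_fx and select_decomp_table are real APIs; whether they handle this particular graph is exactly the open question):

    # Retrace the captured graph, decomposing ops along the way. This is
    # the attempt that fails on ops like graphsafe_run_with_rng_state.
    from torch.fx.experimental.proxy_tensor import make_fx
    from torch._inductor.decomposition import select_decomp_table

    decomposed_gm = make_fx(
        gm, decomposition_table=select_decomp_table()
    )(*example_inputs)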

Contributor:

Hmm @SherlockNoMad here's the way I'd picture doing this in compiler toolkit world but let me know what you think:

(1) Make decomp_table an argument to joint_graph_builder(). So if someone wants to compile with inductor, they need to perform graph capture with their decomp table set to torch._inductor.decomposition.select_decomp_table().

The other option would be to have the user pass in a string backend name like "inductor", where we automatically pick the right decomp table for them.

I guess it depends on how automagic we want to make things. In the toolkit flow, making the user spell out things like their decomposition_table manually fits a bit better with the goals of the compiler toolkit (compared to torch.compile, where there is only one API and we choose all of the right defaults automatically, which "just works" but is harder to reason about). If we are worried about users doing the wrong thing, we could raise an error if the user specified compile_fx_inner as their compiler but forgot to pass in the inductor decomp table.

(2) We'd need to decide on where the decomps should run. I think we have two options:

(2a) as an extra graph pass
(2b) let the decompositions run while we trace out the joint graph

Any preference between the two? They should produce the same result. 2a is more in line with what torch.compile does today. It also has the advantage of being faster, because 2b requires an unnecessary second trace, which is roughly half the cost of running AOTAutograd.

One argument I could see for doing "decomps as a graph pass" is if the user wants to write a pattern-match-based graph pass that runs on the graph before decomps. Say they want to use inductor, but they also want to pattern-match aten.rms_norm into their custom kernel.
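
A minimal sketch of option (1), assuming joint_graph_builder keeps the signature quoted later in this PR (model, args, kwargs, decompositions=...):

    # The caller opts into inductor's decompositions explicitly at capture
    # time; an "aot_eager" user would simply pass decompositions=None.
    from torch._inductor.decomposition import select_decomp_table

    joint_gm = joint_graph_builder(
        model,
        args,
        kwargs,
        decompositions=select_decomp_table(),
    )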

    logger.info("Compiling backward graph with Inductor (compile_fx_inner)")
    return compile_fx_inner(gm, example_inputs)

    else:
Contributor (@SherlockNoMad, Dec 15, 2025):

I think you can rewrite this to something like the following, to avoid code duplication:

    # apply common passes here
    gm = common_passes(...)

    # apply backend-specific passes
    if backend == "inductor":
        gm = compile_fx_inner(gm, ...)

    model,
    args,
    kwargs,
    decompositions=decompositions,
Contributor:

We should be careful about exposing decompositions, since it can change numerics and is a potential footgun.

Contributor (@SherlockNoMad) left a review:

lgtm with comments.

    Example: --compile.passes autobucketing_reordering
    """

    backend: Literal["aot_eager", "inductor"] = "aot_eager"
Contributor:
Also, we should add a warning about numerics-changing behavior when using inductor.
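
A minimal sketch of such a warning, assuming torchtitan's existing logger (placement and wording are illustrative, not from the PR):

    # Warn once at setup time that inductor's decompositions can change
    # numerics relative to the eager baseline.
    if backend == "inductor":
        logger.warning(
            "compile.backend=inductor applies decompositions that may "
            "change numerics relative to eager/aot_eager."
        )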

    Example: --compile.passes autobucketing_reordering
    """

    backend: Literal["aot_eager", "inductor"] = "aot_eager"
Contributor:
The reason we have this Compile class is to add custom fields that don't exist in torchtitan/config/job_config.py. backend is an existing field, so there's no need to add it here.
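
A sketch of the suggested shape, assuming the class is a dataclass like the rest of the job config and that passes is a list of pass names (the list[str] type is an assumption for illustration):

    from dataclasses import dataclass, field

    @dataclass
    class Compile:
        passes: list[str] = field(default_factory=list)
        """Example: --compile.passes autobucketing_reordering"""
        # No `backend` field here: it already exists in
        # torchtitan/config/job_config.py.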


Author (@aditvenk, Dec 16, 2025):

Ah gotcha! Can remove it.

    -passes: List[Callable] = None, dump_folder: str | None = None
    +passes: List[Callable] = None,
    +dump_folder: str | None = None,
    +backend: str = "aot_eager",
Contributor:

I am not sure we need to add this new backend argument to the AOT workflow, since we can introduce full inductor as a graph pass (just like regional inductor).

In this way, we don't need to introduce another "backend knob", and the AOT flow will always be:
eager model -> graph capture -> full graph in ATen IR -> apply graph passes -> optimized graph.

WDYT? @SherlockNoMad

Author:

Full inductor compilation using compile_fx_inner is not an FX graph pass: it takes an FX graph and returns a compiled artifact. But it can be exposed to the CLI as a "pass" and made to work similarly to the current PR.
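
A sketch of the distinction (my_graph_pass is hypothetical; compile_fx_inner is the real API under discussion):

    import torch
    from torch._inductor.compile_fx import compile_fx_inner

    # A graph pass maps GraphModule -> GraphModule, so passes compose:
    def my_graph_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        ...  # rewrite nodes, return the (still inspectable) FX graph

    # compile_fx_inner maps GraphModule -> compiled callable, so no graph
    # pass can run after it; in a "pass pipeline" it must come last.
    compiled = compile_fx_inner(gm, example_inputs)
    out = compiled(example_inputs)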
