[Compiler Toolkit] Add option for full inductor. #2150
base: main
Conversation
Being able to compile fw/bw graphs using compile_fx_inner will help with establishing perf rooflines.
Manual testing:
NGPU=4 \
CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml \
TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train \
./run_train.sh \
--model.name compiler_toolkit.llama3 \
--parallelism.data_parallel_shard_degree=2 \
--parallelism.tensor_parallel_degree=2 \
--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config \
--compile.backend inductor
assert isinstance(arg, DTensor), f"Argument {idx} is of type {type(arg)}"

# Use Inductor's decomposition table when backend is "inductor"
decompositions = select_decomp_table() if backend == "inductor" else None
This is introducing a coupling between the frontend and the backend.
Decomposition should be a per-backend concept, so ideally it should be an internal step of inductor.compile.
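For illustration, a minimal sketch of that suggestion, assuming a hypothetical inductor_compile wrapper owned by the backend; apply_decompositions is a made-up helper, not an existing API, and the frontend would never touch select_decomp_table directly:

from torch._inductor.compile_fx import compile_fx_inner
from torch._inductor.decomposition import select_decomp_table


def inductor_compile(gm, example_inputs):
    # Decomposition is a per-backend detail: the inductor wrapper picks its own table.
    decompositions = select_decomp_table()
    gm = apply_decompositions(gm, example_inputs, decompositions)  # hypothetical helper
    return compile_fx_inner(gm, example_inputs)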
    is_forward=True,
)
logger.info("Compiling forward graph with Inductor (compile_fx_inner)")
return compile_fx_inner(gm, example_inputs)
For example, decomposition would be better applied before or inside compile_fx_inner.
@SherlockNoMad -- what is the recommended way to apply decompositions on the graph module here? I tried make_fx, but it requires tracing the graph again, which fails because the graph module contains ops like graphsafe_run_with_rng_state that it doesn't handle. compile_fx_inner seems to expect the graph to already be decomposed at this point -- the only way I could get it to work was as part of the AOT full-graph capture.
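For reference, a minimal sketch of the make_fx retrace described above (the approach that breaks on ops like graphsafe_run_with_rng_state); gm and example_inputs are assumed to come from the captured graph:

from torch.fx.experimental.proxy_tensor import make_fx
from torch._inductor.decomposition import select_decomp_table

# Retrace the captured graph module with Inductor's decomposition table applied.
decomposed_gm = make_fx(gm, decomposition_table=select_decomp_table())(*example_inputs)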
Hmm @SherlockNoMad here's the way I'd picture doing this in compiler toolkit world but let me know what you think:
(1) Make decomp_table an argument to joint_graph_builder(). So if someone wants to compile with inductor, they need to perform graph capture with their decomp table set to torch._inductor.decomposition.select_decomp_table() (see the sketch after this comment).
The other option would be to have the user pass in a string backend name like "inductor", where we automatically pick the right decomp table for them. I guess it depends on how automagic we want to make things. In the toolkit flow, making the user spell out things like their decomposition_table manually fits a bit better with the goals of the compiler toolkit (compared to torch.compile, where there is only one API and we choose all of the right defaults automatically, which "just works" but is harder to reason about). If we are worried about users doing the wrong thing, we could raise an error if the user specified compile_fx_inner as their compiler but forgot to pass in the inductor decomp table.
(2) We'd need to decide on where the decomps should run. I think we have two options:
(2a) as an extra graph pass
(2b) let the decompositions run while we trace out the joint graph
Any preference between the two? They should produce the same result. 2a is more in line with what torch.compile does today. It also has the advantage of being faster, because 2b requires an unnecessary second trace, which is roughly half the cost of running AOTAutograd.
One argument I could see for doing "decomps as a graph pass" is if the user wants to write a pattern-match-based graph pass on the graph before decomps run. Say they want to use inductor, but they also want to pattern-match aten.rms_norm and replace it with their custom kernel.
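A minimal sketch of option (1), assuming joint_graph_builder accepts the decompositions keyword added in this PR; the result name is illustrative:

from torch._inductor.decomposition import select_decomp_table

# The user opts in to Inductor's decompositions at graph-capture time; we could
# raise an error later if compile_fx_inner is requested without this table.
joint_gm = joint_graph_builder(
    model,
    args,
    kwargs,
    decompositions=select_decomp_table(),
    # other existing arguments (passes, dump_folder, ...) unchanged
)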
logger.info("Compiling backward graph with Inductor (compile_fx_inner)")
return compile_fx_inner(gm, example_inputs)

else:
I think you can rewrite this to something like the following to avoid code duplication:
# apply common passes here
gm = common_passes(...)

# apply backend-specific passes
if backend == "inductor":
    gm = compile_fx_inner(gm, ...)
model,
args,
kwargs,
decompositions=decompositions,
We should be careful about exposing decompositions, since they would change numerics and are a potential footgun.
SherlockNoMad left a comment:
lgtm with comments.
Example: --compile.passes autobucketing_reordering
"""

backend: Literal["aot_eager", "inductor"] = "aot_eager"
Also, we should have a warning about numerics-changing behavior when using inductor.
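A small sketch of what such a warning could look like; the exact wording, placement, and the job_config.compile.backend accessor (assumed to mirror the --compile.backend flag) are just assumptions:

if job_config.compile.backend == "inductor":
    # Inductor decompositions and fusions can change numerics relative to eager.
    logger.warning(
        "compile.backend=inductor may change numerics compared to aot_eager"
    )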
Example: --compile.passes autobucketing_reordering
"""

backend: Literal["aot_eager", "inductor"] = "aot_eager"
The reason we have this Compile class is to add custom fields that don't exist in torchtitan/config/job_config.py. backend is an existing field, so there is no need to add it here.
Ah gotcha! Can remove it.
passes: List[Callable] = None, dump_folder: str | None = None
passes: List[Callable] = None,
dump_folder: str | None = None,
backend: str = "aot_eager",
I am not sure we need to add this new backend argument in the AOT workflow, since we can introduce full inductor as a graph pass (just like regional inductor).
That way we don't need to introduce another "backend knob", and the AOT flow will always be:
eager model -> graph capture -> full graph in ATen IR -> apply graph passes -> optimized graph.
WDYT? @SherlockNoMad
Full inductor compilation using compile_fx_inner is not an FX graph pass: it takes an FX graph and returns a compiled artifact. But it can be exposed to the CLI as a "pass" and made to work similarly to the current PR.
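A rough sketch of that idea: a hypothetical full_inductor "pass" that must run last, since it returns a compiled callable rather than an FX graph; the name and registration mechanism are assumptions:

from torch._inductor.compile_fx import compile_fx_inner


def full_inductor_pass(gm, example_inputs):
    # Not a true FX-to-FX pass: it consumes the final ATen-IR graph and
    # returns a compiled artifact, so it has to be the terminal step.
    return compile_fx_inner(gm, example_inputs)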