-
Notifications
You must be signed in to change notification settings - Fork 644
[Compiler Toolkit] Add option for full inductor. #2150
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,7 @@ | |
| JointWithDescriptors, | ||
| ) | ||
| from torch._guards import tracing, TracingContext | ||
| from torch._inductor.decomposition import select_decomp_table | ||
| from torch.distributed.tensor import DTensor | ||
| from torchtitan.config import JobConfig | ||
| from torchtitan.distributed import ParallelDims | ||
|
|
@@ -37,8 +38,18 @@ def _dump_gm(dump_folder: str | None, gm: torch.fx.GraphModule, name: str) -> No | |
|
|
||
|
|
||
| def export_joint( | ||
| model, args, kwargs=None, dump_folder: str | None = None | ||
| model, args, kwargs=None, dump_folder: str | None = None, decompositions=None | ||
| ) -> tuple[JointWithDescriptors, TracingContext]: | ||
| """ | ||
| Export joint forward-backward graph with AOT Autograd. | ||
|
|
||
| Args: | ||
| model: The model to export | ||
| args: Tuple of input arguments | ||
| kwargs: Dict of keyword arguments for the model | ||
| dump_folder: Optional folder to dump the graph to | ||
| decompositions: Optional decomposition table for AOT Autograd | ||
| """ | ||
| if kwargs is None: | ||
| kwargs = {} | ||
| assert isinstance(args, tuple) | ||
|
|
@@ -62,12 +73,25 @@ def export_joint( | |
|
|
||
| with tracing(tracing_context): | ||
| return ( | ||
| aot_export_joint_with_descriptors_alone(gm, args, kwargs), | ||
| aot_export_joint_with_descriptors_alone( | ||
| gm, args, kwargs, decompositions=decompositions | ||
| ), | ||
| tracing_context, | ||
| ) | ||
|
|
||
|
|
||
| def aot_export_joint_with_descriptors_alone(model, args, kwargs=None): | ||
| def aot_export_joint_with_descriptors_alone( | ||
| model, args, kwargs=None, decompositions=None | ||
| ): | ||
| """ | ||
| Export joint forward-backward graph with AOT Autograd. | ||
|
|
||
| Args: | ||
| model: The model to export | ||
| args: Tuple of input arguments | ||
| kwargs: Dict of keyword arguments for the model | ||
| decompositions: Optional decomposition table for AOT Autograd. | ||
| """ | ||
| if kwargs is None: | ||
| kwargs = {} | ||
| assert isinstance(args, tuple) | ||
|
|
@@ -78,6 +102,7 @@ def aot_export_joint_with_descriptors_alone(model, args, kwargs=None): | |
| model, | ||
| args, | ||
| kwargs, | ||
| decompositions=decompositions, | ||
| ) | ||
| return joint_with_descriptors | ||
|
|
||
|
|
@@ -90,6 +115,7 @@ def joint_graph_builder( | |
| bw_compiler: Optional[Callable] = None, | ||
| joint_custom_passes: Optional[List[Callable]] = None, | ||
| dump_folder: str | None = None, | ||
| backend: str = "aot_eager", | ||
| ): | ||
| """ | ||
| Build a joint forward-backward graph for the model with optional custom compilers. | ||
|
|
@@ -102,16 +128,23 @@ def joint_graph_builder( | |
| bw_compiler: Optional custom backward compiler function | ||
| joint_custom_passes: list of custom passes to run on the joint graph | ||
| dump_folder: Optional folder to dump the graph to | ||
| backend: Compilation backend ("aot_eager", "inductor") | ||
| """ | ||
| assert isinstance(model_args, tuple) | ||
| for idx, arg in enumerate(model_args): | ||
| assert isinstance(arg, DTensor), f"Argument {idx} is of type {type(arg)}" | ||
|
|
||
| # Use Inductor's decomposition table when backend is "inductor" | ||
| decompositions = select_decomp_table() if backend == "inductor" else None | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is introducing a coupling between frontend and backend. Decomposition should be a per-backend concept, so ideally it should be an internal step of inductor.compile. |
||
|
|
||
| # get joint graph | ||
| ( | ||
| joint_with_descriptors, | ||
| tracing_context, | ||
| ) = export_joint(model, model_args, model_kwargs, dump_folder=dump_folder) | ||
| (joint_with_descriptors, tracing_context,) = export_joint( | ||
| model, | ||
| model_args, | ||
| model_kwargs, | ||
| dump_folder=dump_folder, | ||
| decompositions=decompositions, | ||
| ) | ||
|
|
||
| # run custom passes on joint-graph before partitioner | ||
| if joint_custom_passes is not None: | ||
|
|
@@ -270,37 +303,70 @@ def compiler( | |
|
|
||
|
|
||
| def make_compiler_with_passes( | ||
| passes: List[Callable] = None, dump_folder: str | None = None | ||
| passes: List[Callable] = None, | ||
| dump_folder: str | None = None, | ||
| backend: str = "aot_eager", | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not sure if we need to add this new In this way, we don't need to introduce another "backend knob" and the AOT flow will always be: WDTY? @SherlockNoMad
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Full inductor compilation using |
||
| ): | ||
| """ | ||
| Create forward and backward compilers with specified passes. | ||
| Create forward and backward compilers with specified passes and backend. | ||
|
|
||
| Args: | ||
| passes: List of compiler pass functions to apply. If None, uses DEFAULT_COMPILER_PASSES. | ||
| dump_folder: Optional folder to dump graphs | ||
| backend: Compilation backend ("aot_eager", "inductor") | ||
|
|
||
| Returns: | ||
| Tuple of (fw_compiler, bw_compiler) functions | ||
| """ | ||
| from torch._inductor.compile_fx import compile_fx_inner | ||
|
|
||
| if backend == "inductor": | ||
| # Use compile_fx_inner as the final compiler after applying transformation passes | ||
| def fw_compiler(gm: torch.fx.GraphModule, example_inputs): | ||
| gm = compiler( | ||
| "fwd_gm", | ||
| gm, | ||
| example_inputs, | ||
| passes=passes, | ||
| dump_folder=dump_folder, | ||
| is_forward=True, | ||
| ) | ||
| logger.info("Compiling forward graph with Inductor (compile_fx_inner)") | ||
| return compile_fx_inner(gm, example_inputs) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For example, decomposition would be better applied before or inside
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @SherlockNoMad -- what is the recommended way to apply decompositions on the graph module here? I tried
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm @SherlockNoMad here's the way I'd picture doing this in compiler toolkit world but let me know what you think: (1) make The other option would be to have the user pass in a string backend name like (2) We'd need to decide on where the decomps should run. I think we have two options: (2a) as an extra graph pass Any preference between the two? They should produce the same result. 2a is more inline with what torch.compile does today. It also has the advantage of being faster (because 2b requires doing an unnecessary second trace, which is roughly ~half the cost of running AOTAutograd). One argument I could see for doing "decomps as a graph pass" is if the user wants to write a pattern match based graph pass on the graph before decomps run. Say they want to use inductor, but they also want to pattern match
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think the design spirit of "compiler toolkit" would be "highly customizable, with modular off-the-shelf components." To spell it out:
There could be coupling between passes, e.g. "inductor pass" would require "inductor decomp pass" as a pre-requisite.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have tried this in a few different ways but ran into stumbling blocks. Let's discuss offline. |
||
|
|
||
| def bw_compiler(gm: torch.fx.GraphModule, example_inputs): | ||
| gm = compiler( | ||
| "bwd_gm", | ||
| gm, | ||
| example_inputs, | ||
| passes=passes, | ||
| dump_folder=dump_folder, | ||
| is_forward=False, | ||
| ) | ||
| logger.info("Compiling backward graph with Inductor (compile_fx_inner)") | ||
| return compile_fx_inner(gm, example_inputs) | ||
|
|
||
| else: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think you can rewrite this to something like the following to avoid code duplication |
||
|
|
||
| def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: | ||
| return compiler( | ||
| "fwd_gm", | ||
| gm, | ||
| example_inputs, | ||
| passes=passes, | ||
| dump_folder=dump_folder, | ||
| is_forward=True, | ||
| ) | ||
|
|
||
| def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: | ||
| return compiler( | ||
| "fwd_gm", | ||
| gm, | ||
| example_inputs, | ||
| passes=passes, | ||
| dump_folder=dump_folder, | ||
| is_forward=True, | ||
| ) | ||
|
|
||
| def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: | ||
| return compiler( | ||
| "bwd_gm", | ||
| gm, | ||
| example_inputs, | ||
| passes=passes, | ||
| dump_folder=dump_folder, | ||
| is_forward=False, | ||
| ) | ||
| def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: | ||
| return compiler( | ||
| "bwd_gm", | ||
| gm, | ||
| example_inputs, | ||
| passes=passes, | ||
| dump_folder=dump_folder, | ||
| is_forward=False, | ||
| ) | ||
|
|
||
| return fw_compiler, bw_compiler | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,16 +5,24 @@ | |
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| from dataclasses import dataclass, field | ||
| from typing import Literal | ||
|
|
||
|
|
||
| @dataclass | ||
| class Compile: | ||
| """ | ||
| List of compiler pass names to apply in the compiler toolkit workflow. | ||
| By default, no passes are applied. | ||
| Example: --compile.passes autobucketing_reordering,regional_inductor | ||
| Compiler configuration for the compiler toolkit workflow. | ||
| - backend: The compilation backend to use. Options are: | ||
| - "aot_eager": AOT Autograd with eager backend (graph transformations only) | ||
| - "inductor": Full Inductor compilation with optimized code generation | ||
| - passes: List of compiler pass names to apply in the compiler toolkit workflow. | ||
| Example: --compile.passes autobucketing_reordering | ||
| """ | ||
|
|
||
| backend: Literal["aot_eager", "inductor"] = "aot_eager" | ||
aditvenk marked this conversation as resolved.
Show resolved
Hide resolved
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Also we should have a warning about numerics changing behavior when using inductor.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The reason we have this
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah gotcha! Can remove it. |
||
| passes: list[str] = field(default_factory=list) | ||
|
|
||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we should be careful about exposing decompositions, since it would change numerics, and is a potential footgun.