diff --git a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py
index 011bbe402a..dbfe724b11 100644
--- a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py
+++ b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py
@@ -76,15 +76,18 @@ def parallelize_deepseekv3(
     with disable_compile(job_config):
         model = simple_fsdp_parallelize_deepseekv3(model, parallel_dims, job_config)
 
+    # Get backend from config
+    backend = job_config.compile.backend
+
     # Get joint custom passes from config
     joint_custom_passes = get_joint_custom_passes_from_config(parallel_dims, job_config)
 
     # Get compiler passes from config
     compiler_passes = get_compiler_passes_from_config(model, job_config)
 
-    # Create compilers with specified passes (defaults to no passes)
+    # Create compilers with specified passes and backend
     fw_compiler, bw_compiler = make_compiler_with_passes(
-        compiler_passes, dump_folder=job_config.job.dump_folder
+        compiler_passes, dump_folder=job_config.job.dump_folder, backend=backend
     )
 
     # Create custom joint_graph_builder with deepseekv3-specific compilers
@@ -94,6 +97,7 @@ def parallelize_deepseekv3(
         bw_compiler=bw_compiler,
         joint_custom_passes=joint_custom_passes,
         dump_folder=job_config.job.dump_folder,
+        backend=backend,
     )
 
     # TODO: CompiledModule should take sample input as well, so that we can
diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py
index e097579cc0..680ee17f4e 100644
--- a/torchtitan/experiments/compiler_toolkit/graph_utils.py
+++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py
@@ -17,6 +17,7 @@
     JointWithDescriptors,
 )
 from torch._guards import tracing, TracingContext
+from torch._inductor.decomposition import select_decomp_table
 from torch.distributed.tensor import DTensor
 from torchtitan.config import JobConfig
 from torchtitan.distributed import ParallelDims
@@ -37,8 +38,18 @@ def _dump_gm(dump_folder: str | None, gm: torch.fx.GraphModule, name: str) -> No
 
 
 def export_joint(
-    model, args, kwargs=None, dump_folder: str | None = None
+    model, args, kwargs=None, dump_folder: str | None = None, decompositions=None
 ) -> tuple[JointWithDescriptors, TracingContext]:
+    """
+    Export joint forward-backward graph with AOT Autograd.
+
+    Args:
+        model: The model to export
+        args: Tuple of input arguments
+        kwargs: Dict of keyword arguments for the model
+        dump_folder: Optional folder to dump the graph to
+        decompositions: Optional decomposition table for AOT Autograd
+    """
     if kwargs is None:
         kwargs = {}
     assert isinstance(args, tuple)
@@ -62,12 +73,25 @@ def export_joint(
 
     with tracing(tracing_context):
         return (
-            aot_export_joint_with_descriptors_alone(gm, args, kwargs),
+            aot_export_joint_with_descriptors_alone(
+                gm, args, kwargs, decompositions=decompositions
+            ),
             tracing_context,
         )
 
 
-def aot_export_joint_with_descriptors_alone(model, args, kwargs=None):
+def aot_export_joint_with_descriptors_alone(
+    model, args, kwargs=None, decompositions=None
+):
+    """
+    Export joint forward-backward graph with AOT Autograd.
+
+    Args:
+        model: The model to export
+        args: Tuple of input arguments
+        kwargs: Dict of keyword arguments for the model
+        decompositions: Optional decomposition table for AOT Autograd
+ """ if kwargs is None: kwargs = {} assert isinstance(args, tuple) @@ -78,6 +102,7 @@ def aot_export_joint_with_descriptors_alone(model, args, kwargs=None): model, args, kwargs, + decompositions=decompositions, ) return joint_with_descriptors @@ -90,6 +115,7 @@ def joint_graph_builder( bw_compiler: Optional[Callable] = None, joint_custom_passes: Optional[List[Callable]] = None, dump_folder: str | None = None, + backend: str = "aot_eager", ): """ Build a joint forward-backward graph for the model with optional custom compilers. @@ -102,16 +128,23 @@ def joint_graph_builder( bw_compiler: Optional custom backward compiler function joint_custom_passes: list of custom passes to run on the joint graph dump_folder: Optional folder to dump the graph to + backend: Compilation backend ("aot_eager", "inductor") """ assert isinstance(model_args, tuple) for idx, arg in enumerate(model_args): assert isinstance(arg, DTensor), f"Argument {idx} is of type {type(arg)}" + # Use Inductor's decomposition table when backend is "inductor" + decompositions = select_decomp_table() if backend == "inductor" else None + # get joint graph - ( - joint_with_descriptors, - tracing_context, - ) = export_joint(model, model_args, model_kwargs, dump_folder=dump_folder) + (joint_with_descriptors, tracing_context,) = export_joint( + model, + model_args, + model_kwargs, + dump_folder=dump_folder, + decompositions=decompositions, + ) # run custom passes on joint-graph before partitioner if joint_custom_passes is not None: @@ -270,37 +303,70 @@ def compiler( def make_compiler_with_passes( - passes: List[Callable] = None, dump_folder: str | None = None + passes: List[Callable] = None, + dump_folder: str | None = None, + backend: str = "aot_eager", ): """ - Create forward and backward compilers with specified passes. + Create forward and backward compilers with specified passes and backend. Args: passes: List of compiler pass functions to apply. If None, uses DEFAULT_COMPILER_PASSES. 
+        dump_folder: Optional folder to dump graphs
+        backend: Compilation backend ("aot_eager", "inductor")
 
     Returns:
         Tuple of (fw_compiler, bw_compiler) functions
     """
+    from torch._inductor.compile_fx import compile_fx_inner
+
+    if backend == "inductor":
+        # Use compile_fx_inner as the final compiler after applying transformation passes
+        def fw_compiler(gm: torch.fx.GraphModule, example_inputs):
+            gm = compiler(
+                "fwd_gm",
+                gm,
+                example_inputs,
+                passes=passes,
+                dump_folder=dump_folder,
+                is_forward=True,
+            )
+            logger.info("Compiling forward graph with Inductor (compile_fx_inner)")
+            return compile_fx_inner(gm, example_inputs)
+
+        def bw_compiler(gm: torch.fx.GraphModule, example_inputs):
+            gm = compiler(
+                "bwd_gm",
+                gm,
+                example_inputs,
+                passes=passes,
+                dump_folder=dump_folder,
+                is_forward=False,
+            )
+            logger.info("Compiling backward graph with Inductor (compile_fx_inner)")
+            return compile_fx_inner(gm, example_inputs)
+
+    else:
+
+        def fw_compiler(gm: torch.fx.GraphModule, example_inputs):
+            return compiler(
+                "fwd_gm",
+                gm,
+                example_inputs,
+                passes=passes,
+                dump_folder=dump_folder,
+                is_forward=True,
+            )
 
-    def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
-        return compiler(
-            "fwd_gm",
-            gm,
-            example_inputs,
-            passes=passes,
-            dump_folder=dump_folder,
-            is_forward=True,
-        )
-
-    def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
-        return compiler(
-            "bwd_gm",
-            gm,
-            example_inputs,
-            passes=passes,
-            dump_folder=dump_folder,
-            is_forward=False,
-        )
+        def bw_compiler(gm: torch.fx.GraphModule, example_inputs):
+            return compiler(
+                "bwd_gm",
+                gm,
+                example_inputs,
+                passes=passes,
+                dump_folder=dump_folder,
+                is_forward=False,
+            )
 
     return fw_compiler, bw_compiler
 
diff --git a/torchtitan/experiments/compiler_toolkit/job_config.py b/torchtitan/experiments/compiler_toolkit/job_config.py
index ec5829a6c9..0d46d0719b 100644
--- a/torchtitan/experiments/compiler_toolkit/job_config.py
+++ b/torchtitan/experiments/compiler_toolkit/job_config.py
@@ -5,16 +5,24 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass, field
+from typing import Literal
 
 
 @dataclass
 class Compile:
     """
-    List of compiler pass names to apply in the compiler toolkit workflow.
-    By default, no passes are applied.
-    Example: --compile.passes autobucketing_reordering,regional_inductor
+    Compiler configuration for the compiler toolkit workflow.
+
+    - backend: The compilation backend to use. Options are:
+        - "aot_eager": AOT Autograd with eager backend (graph transformations only)
+        - "inductor": Full Inductor compilation with optimized code generation
+
+    - passes: List of compiler pass names to apply in the compiler toolkit workflow.
+
+    Example: --compile.passes autobucketing_reordering
     """
 
+    backend: Literal["aot_eager", "inductor"] = "aot_eager"
     passes: list[str] = field(default_factory=list)
 
 
diff --git a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py
index 68fa7443f4..2939db4282 100644
--- a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py
+++ b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py
@@ -63,15 +63,18 @@ def parallelize_llama(
     with disable_compile(job_config):
         model = simple_fsdp_parallelize_llama(model, parallel_dims, job_config)
 
+    # Get backend from config
+    backend = job_config.compile.backend
+
     # Get joint custom passes from config
     joint_custom_passes = get_joint_custom_passes_from_config(parallel_dims, job_config)
 
     # Get compiler passes from config
     compiler_passes = get_compiler_passes_from_config(model, job_config)
 
-    # Create compilers with specified passes (defaults to no passes)
+    # Create compilers with specified passes and backend
     fw_compiler, bw_compiler = make_compiler_with_passes(
-        compiler_passes, dump_folder=job_config.job.dump_folder
+        compiler_passes, dump_folder=job_config.job.dump_folder, backend=backend
     )
 
     # Create custom joint_graph_builder with llama-specific compilers
@@ -81,6 +84,7 @@ def parallelize_llama(
         bw_compiler=bw_compiler,
         joint_custom_passes=joint_custom_passes,
         dump_folder=job_config.job.dump_folder,
+        backend=backend,
     )
 
     # TODO: CompiledModule should take sample input as well, so that we can
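
Usage note (editor's addition, not part of the patch): with this change the backend is selected from the job config rather than being fixed to aot_eager; on the command line that looks like --compile.backend inductor, optionally combined with --compile.passes autobucketing_reordering as in the Compile docstring above. A minimal sketch of the resulting wiring, assuming a parsed job_config and an already-resolved compiler_passes list:

    # Editor's sketch, not part of the patch: how the backend flag flows through
    # the toolkit (mirrors the parallelize_llama changes above).
    backend = job_config.compile.backend  # "aot_eager" (default) or "inductor"

    # Graph-transformation passes run either way; with backend="inductor" the
    # transformed forward/backward graphs are additionally compiled by Inductor
    # via compile_fx_inner, after tracing with Inductor's decomposition table.
    fw_compiler, bw_compiler = make_compiler_with_passes(
        compiler_passes,
        dump_folder=job_config.job.dump_folder,
        backend=backend,
    )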