Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
add docstrings for recipes
  • Loading branch information
KaelanDt committed Jun 4, 2025
commit 6b63c19878cb13bb98cdd526d454cb628dfda119
10 changes: 10 additions & 0 deletions docs/source/reference/recipes/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
.. module:: thunder.recipes

thunder.recipes
==================

.. autosummary::
:toctree: generated/

BaseRecipe
HFTransformers
49 changes: 49 additions & 0 deletions thunder/recipes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,23 @@ def get_nvfuser_package_hint() -> str:


class BaseRecipe(Recipe):
"""
Compilation recipe with Thunder defaults. The recipe wires a set of executors, transforms
and debug options, while providing a single switch to pick the
fusion backend (“nvfuser” or “torch.compile”). Should be used as a template to extend.

Args:
show_progress: bool, default False
Print interpreter-side progress bars.
fuser: {"nvfuser", "torch.compile"}, default "nvfuser"
Fusion backend to register. Adjusts ``self.executor_names`` so the
chosen backend is present and any mutually-exclusive one is removed.
interpreter: str, default "thunder.jit"
Interpreter identifier forwarded to :class:`Recipe`.
plugins: Iterable | None
Extra Thunder plugins to enable.
"""

def __init__(
self,
show_progress=False,
Expand All @@ -65,16 +82,37 @@ def __init__(
self.show_progress = show_progress

def setup_config(self) -> dict[str, Any]:
    """Assemble the per-run Thunder configuration for this recipe.

    Returns:
        dict[str, Any]: An empty dict when ``show_progress`` is disabled;
        otherwise a dict whose ``debug_options`` entry is a
        :class:`DebugOptions` with ``show_interpreter_progress=True``.
    """
    if self.show_progress:
        return {"debug_options": DebugOptions(show_interpreter_progress=True)}
    return {}

def setup_transforms(self) -> list[Transform]:
    """Build the list of graph-level transforms applied by this recipe.

    Returns:
        list[Transform]: Currently a single ``PrunePrologueChecks`` instance;
        subclasses may extend the returned list.
    """
    return [PrunePrologueChecks()]

def setup_fuser(self) -> None:
"""
Reconciles the requested fusion backend with ``self.executor_names``.

Raises:
NotImplementedError: If *fuser* is not ``"nvfuser"`` or ``"torch.compile"``.
"""

if self.fuser == "nvfuser":
if "nvfuser" not in self.executor_names:
self.executor_names.append("nvfuser")
Expand All @@ -89,6 +127,17 @@ def setup_fuser(self) -> None:
)

def setup_executors(self) -> list[Executor]:
"""
Resolves executor names to concrete :class:`Executor` objects.

Returns:
list[Executor]: Instantiated executors in the order given by
``self.executor_names``.

Raises:
TypeError: If ``self.executor_names`` is not a list.
ValueError: If a non-nvfuser executor cannot be found.
"""
if not isinstance(self.executor_names, list):
raise TypeError(
f"self.executor_names must be a list of executor names, got {type(self.executor_names).__name__}"
Expand Down
52 changes: 52 additions & 0 deletions thunder/recipes/hf_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ def transform_traces_pre_prologue(self, pro, comp, epi, **kwargs):


class HFTransformers(BaseRecipe):
"""
Recipe tuned for Hugging Face ``transformers`` models.

Args:
show_progress (bool, optional): Forwarded to :class:`BaseRecipe`.
interpreter (str, optional): Thunder interpreter to use.
plugins (Iterable | None, optional): Extra Thunder plugins.
"""

def __init__(
self,
show_progress=False,
Expand All @@ -52,6 +61,16 @@ def __init__(

@classmethod
def validate(cls, model):
"""
Emit warnings (or errors) if *model* falls outside the supported
transformer versions or base classes.

Args:
model (transformers.PreTrainedModel): Model instance to vet.

Raises:
ValueError: If *model* is not a ``PreTrainedModel``.
"""
import transformers

version = LooseVersion(transformers.__version__)
Expand Down Expand Up @@ -80,11 +99,26 @@ def validate(cls, model):
raise ValueError(f"The model must be an instance of PreTrainedModel, found {type(model)}")

def setup_config(self):
    """Extend the base recipe's configuration with nvFuser op flags.

    Returns:
        dict[str, Any]: The config produced by :meth:`BaseRecipe.setup_config`
        augmented with ``nv_enable_linear``, ``nv_enable_matmul`` and
        ``nv_enable_sdpa`` all set to ``True``.
    """
    base_config = super().setup_config()
    return {
        **base_config,
        "nv_enable_linear": True,
        "nv_enable_matmul": True,
        "nv_enable_sdpa": True,
    }

def setup_lookasides(self):
"""
Swap out the warning helper when running under
the non Thunder-FX interpreter.

Returns:
list[thunder.core.recipe.Lookaside] | None
"""
if self.interpreter == thunder.core.recipe.Interpreter.THUNDER_FX:
return None

Expand All @@ -98,10 +132,28 @@ def setup_lookasides(self):
return [warn_lookaside]

def setup_transforms(self):
    """Prepend the in-place ``index_copy`` transform to the base transforms.

    Returns:
        list[thunder.Transform]: ``self.inplace_index_copy_transform`` followed
        by the transforms supplied by the base recipe.
    """
    base_transforms = super().setup_transforms()
    return [self.inplace_index_copy_transform, *base_transforms]

def apply(self, model):
"""
Apply the recipe (compile the model) and patch ``generate`` / ``_sample``
so they work after tracing.

Args:
model (transformers.PreTrainedModel): The model to compile.

Returns:
transformers.PreTrainedModel: Thunder-compiled model ready
for inference.
"""
thunder_model = super().apply(model)

if getattr(thunder_model, "generate", None):
Expand Down