Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
move registration to cudnn_sdpa
  • Loading branch information
KaelanDt committed Jun 13, 2025
commit 7e885bb92d94214feaef4ee349bb2871c3f969b7
26 changes: 26 additions & 0 deletions thunder/executors/cudnn_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,3 +732,29 @@ def _cudnn_sdpa_bwd_wrapper(
put_grads((query, key, value), (grad_query, grad_key, grad_value))

return primal


# Module-level registration of the cuDNN executor and its SDPA operators.
# The statements below run in order at import time; later ones use the
# objects created by earlier ones.
#
# NOTE(review): the diff header places this code in
# thunder/executors/cudnn_sdpa.py, yet it reads attributes through the name
# `sdpa_impl`, which is not bound anywhere in the visible lines. Confirm that
# `sdpa_impl` (and `cudnn`, `OperatorExecutor`, `register_executor`) are in
# scope in this module before these statements execute.

# Create the executor named "cudnn", versioned with cudnn.backend_version(),
# and register it.
cudnn_ex = OperatorExecutor("cudnn", version=cudnn.backend_version())
register_executor(cudnn_ex)
# Forward operator: registered under "cudnn_sdpa_fwd" with its meta function,
# its implementation, and the DONT_AUTO_RECOMPUTE_IN_BACKWARD tag.
cudnn_sdpa_fwd = cudnn_ex.register_operator(
"cudnn_sdpa_fwd",
meta=sdpa_impl._cudnn_sdpa_forward_meta,
fn=sdpa_impl._cudnn_sdpa_fwd_impl,
tags=(sdpa_impl.OpTags.DONT_AUTO_RECOMPUTE_IN_BACKWARD,),
)

# Backward operator: registered under "cudnn_sdpa_bwd"; no tags are passed.
cudnn_sdpa_bwd = cudnn_ex.register_operator(
"cudnn_sdpa_bwd",
meta=sdpa_impl._cudnn_sdpa_bwd_meta,
fn=sdpa_impl._cudnn_sdpa_bwd_impl,
)

# Store the two operator handles as attributes on `sdpa_impl`.
# Presumably the fwd/bwd wrappers look them up there - verify against the wrappers.
sdpa_impl.cudnn_sdpa_fwd = cudnn_sdpa_fwd
sdpa_impl.cudnn_sdpa_bwd = cudnn_sdpa_bwd

# Register this executor's implementation of
# ltorch.scaled_dot_product_attention, passing the checker, the execution
# transform (forward wrapper) and the grad transform (backward wrapper).
cudnn_ex.register_implementation(
sdpa_impl.ltorch.scaled_dot_product_attention,
checker=sdpa_impl._cudnn_sdpa_checker,
execution_transform=sdpa_impl._cudnn_sdpa_fwd_wrapper,
grad_transform=sdpa_impl._cudnn_sdpa_bwd_wrapper,
)
36 changes: 6 additions & 30 deletions thunder/executors/cudnnex.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Any
from collections.abc import Callable

from lightning_utilities.core.imports import package_available
from looseversion import LooseVersion
from thunder.extend import OperatorExecutor, register_executor

# NOTE(review): two `__all__` assignments appear back to back here, which
# looks like a diff's old and new line shown together. If both execute, the
# second one (which also lists "torch_to_cudnn_dtype") is the value that
# remains - confirm only one is present in the real file.
__all__ = ["cudnn_version", "required_cudnn_version", "cudnn_available", "cudnn_ex"]
__all__ = ["cudnn_version", "required_cudnn_version", "cudnn_available", "cudnn_ex", "torch_to_cudnn_dtype"]


#
Expand Down Expand Up @@ -50,38 +50,14 @@ def cudnn_available() -> bool:


# Module-level defaults, each bound to None here; assignments further down
# rebind `cudnn_ex` (and, in some of the lines shown, the two operator names).
# NOTE(review): this span comes from a diff view without +/- markers, so some
# of these lines may be removals - confirm against the real file.
cudnn_ex: None | OperatorExecutor = None
cudnn_sdpa_fwd: Any = None
cudnn_sdpa_bwd: Any = None

torch_to_cudnn_dtype: None | Callable
cudnn = None

# Setup that is meant to run only when cudnn_available() is true.
# NOTE(review): this span comes from a diff view: indentation is lost and
# removed/added lines are shown together without markers. Confirm against the
# real file which of the statements below still exist and that they sit
# inside this `if`.
if cudnn_available():
# NOTE(review): `import cudnn` appears both before and after the sdpa_impl
# import - likely the old and the new ordering shown together.
import cudnn
import thunder.executors.cudnn_sdpa as sdpa_impl
import cudnn

# Hand the imported cudnn module to sdpa_impl by setting a module attribute.
sdpa_impl.cudnn = cudnn

# Executor creation and operator registration performed in this module.
# The same statements also appear in the cudnn_sdpa.py hunk of this diff, so
# these are presumably the lines being removed - TODO confirm.
cudnn_ex = OperatorExecutor("cudnn", version=cudnn.backend_version())
register_executor(cudnn_ex)
cudnn_sdpa_fwd = cudnn_ex.register_operator(
"cudnn_sdpa_fwd",
meta=sdpa_impl._cudnn_sdpa_forward_meta,
fn=sdpa_impl._cudnn_sdpa_fwd_impl,
tags=(sdpa_impl.OpTags.DONT_AUTO_RECOMPUTE_IN_BACKWARD,),
)

cudnn_sdpa_bwd = cudnn_ex.register_operator(
"cudnn_sdpa_bwd",
meta=sdpa_impl._cudnn_sdpa_bwd_meta,
fn=sdpa_impl._cudnn_sdpa_bwd_impl,
)

# Store the operator handles as attributes on sdpa_impl.
sdpa_impl.cudnn_sdpa_fwd = cudnn_sdpa_fwd
sdpa_impl.cudnn_sdpa_bwd = cudnn_sdpa_bwd

cudnn_ex.register_implementation(
sdpa_impl.ltorch.scaled_dot_product_attention,
checker=sdpa_impl._cudnn_sdpa_checker,
execution_transform=sdpa_impl._cudnn_sdpa_fwd_wrapper,
grad_transform=sdpa_impl._cudnn_sdpa_bwd_wrapper,
)
# Re-export `torch_to_cudnn_dtype` and `cudnn_ex` from sdpa_impl under this
# module's own names.
torch_to_cudnn_dtype = sdpa_impl.torch_to_cudnn_dtype
cudnn_ex = sdpa_impl.cudnn_ex