Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
reorganize cudnn_sdpa and remove unnecessary
imports
  • Loading branch information
KaelanDt committed Jun 13, 2025
commit 3568fe6e1c1872324cbe5e9a7b96a8a8503b6022
64 changes: 40 additions & 24 deletions thunder/executors/cudnn_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,15 @@

from thunder.extend import OperatorExecutor, register_executor

import cudnn

# Query the cudnn backend version once at import time so it can be attached
# to the executor below.
cudnn_backend_version = cudnn.backend_version()

# Create the operator executor named "cudnn", tagged with the backend version,
# and register it via thunder.extend.register_executor.
cudnn_ex: OperatorExecutor = OperatorExecutor("cudnn", version=cudnn_backend_version)
register_executor(cudnn_ex)

Comment thread
KaelanDt marked this conversation as resolved.

# Cache already built cudnn graphs to save on runtime compilation overhead
class CudnnexLRUCache(OrderedDict):
def __init__(self, maxlen, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand All @@ -35,11 +43,25 @@ def __setitem__(self, key, value):


# Module-level cache instance shared by this file.
# NOTE(review): maxlen=1024 presumably bounds the number of cached entries;
# confirm against the CudnnexLRUCache implementation.
_cudnnex_cache = CudnnexLRUCache(maxlen=1024)
# Mapping from device to cudnn handles
_device_to_cudnn_handle = {}


# This function creates a new handle for the device that cudnn should
# run its kernels on. As the approach suggested by cudnn is to make as few handles
# as possible, this function caches these per-device handles.
def _get_cudnn_handle(query_device):
    """Return the cudnn handle cached for ``query_device``, creating it on first use.

    The handle is looked up in the module-level ``_device_to_cudnn_handle``
    mapping. When no handle is stored for the device, one is created while
    ``query_device`` is the active CUDA device and then saved in the mapping.
    On every call the handle's stream is set to the current CUDA stream of
    ``query_device`` before the handle is returned.
    """
    import cudnn

    cached_handle = _device_to_cudnn_handle.get(query_device)
    if cached_handle is None:
        # No handle for this device yet: create one with the device active.
        with torch.cuda.device(query_device):
            cached_handle = cudnn.create_handle()
            _device_to_cudnn_handle[query_device] = cached_handle

    # The user may have switched streams since the last call, so always
    # re-bind the handle to the current stream's raw pointer.
    user_stream = torch.cuda.current_stream(device=query_device)
    cudnn.set_stream(handle=cached_handle, stream=user_stream.cuda_stream)

    return cached_handle

handle = _device_to_cudnn_handle.get(query_device, None)
if handle is None:
Expand All @@ -53,8 +75,6 @@ def _get_cudnn_handle(query_device):
def _make_cudnn_sdpa_forward_graph(
query, key, value, attn_mask, dropout_p, is_causal, query_stride, key_stride, value_stride
):
import cudnn

graph = cudnn.pygraph(
intermediate_data_type=cudnn.data_type.FLOAT,
compute_data_type=cudnn.data_type.FLOAT,
Expand Down Expand Up @@ -353,7 +373,7 @@ def _cudnn_sdpa_checker(
key_stride,
value_stride, # Use the same strides as inputs for their respective grads
)
# If cudnn can't support the graph, return false

# Please turn on cudnn API logging for helpful messages that mention why the graph is not supported.
# For cudnn backend logging, refer https://docs.nvidia.com/deeplearning/cudnn/latest/reference/troubleshooting.html
# For cudnn frontend logging, refer https://github.com/NVIDIA/cudnn-frontend?tab=readme-ov-file#debugging
Expand All @@ -367,6 +387,14 @@ def _cudnn_sdpa_checker(
return True


Comment thread
KaelanDt marked this conversation as resolved.
# Register the forward operator "cudnn_sdpa_fwd" on the cudnn executor, pairing
# its meta function with its implementation. The operator is tagged with
# OpTags.DONT_AUTO_RECOMPUTE_IN_BACKWARD.
cudnn_sdpa_fwd = cudnn_ex.register_operator(
"cudnn_sdpa_fwd",
meta=_cudnn_sdpa_forward_meta,
fn=_cudnn_sdpa_fwd_impl,
tags=(OpTags.DONT_AUTO_RECOMPUTE_IN_BACKWARD,),
)


def _make_cudnn_sdpa_backward_graph(
query,
key,
Expand All @@ -381,8 +409,6 @@ def _make_cudnn_sdpa_backward_graph(
grad_key_stride,
grad_value_stride,
):
import cudnn

b, h, s_q, _ = query.shape
_, _, _, d_v = value.shape

Expand Down Expand Up @@ -659,6 +685,13 @@ def _cudnn_sdpa_bwd_impl(
return grads


Comment thread
KaelanDt marked this conversation as resolved.
# Register the backward operator "cudnn_sdpa_bwd" on the cudnn executor, pairing
# its meta function with its implementation. No tags are passed here.
cudnn_sdpa_bwd = cudnn_ex.register_operator(
"cudnn_sdpa_bwd",
meta=_cudnn_sdpa_bwd_meta,
fn=_cudnn_sdpa_bwd_impl,
)


def _cudnn_sdpa_fwd_wrapper(
query: TensorProxy,
key: TensorProxy,
Expand Down Expand Up @@ -733,24 +766,7 @@ def _cudnn_sdpa_bwd_wrapper(
return primal


import cudnn

# NOTE(review): the statements below repeat the executor creation and the
# "cudnn_sdpa_fwd" / "cudnn_sdpa_bwd" operator registrations that already
# appear earlier in this listing — this looks like the pre-reorganization
# copy; confirm only one copy remains in the actual file.

# Create the "cudnn" executor with the cudnn backend version and register it.
cudnn_ex = OperatorExecutor("cudnn", version=cudnn.backend_version())
register_executor(cudnn_ex)
# Forward operator: meta + implementation, tagged DONT_AUTO_RECOMPUTE_IN_BACKWARD.
cudnn_sdpa_fwd = cudnn_ex.register_operator(
"cudnn_sdpa_fwd",
meta=_cudnn_sdpa_forward_meta,
fn=_cudnn_sdpa_fwd_impl,
tags=(OpTags.DONT_AUTO_RECOMPUTE_IN_BACKWARD,),
)

# Backward operator: meta + implementation, no tags.
cudnn_sdpa_bwd = cudnn_ex.register_operator(
"cudnn_sdpa_bwd",
meta=_cudnn_sdpa_bwd_meta,
fn=_cudnn_sdpa_bwd_impl,
)


# Registers the implementation for torch.nn.functional.scaled_dot_product_attention
cudnn_ex.register_implementation(
ltorch.scaled_dot_product_attention,
checker=_cudnn_sdpa_checker,
Expand Down
Loading