Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 14 additions & 19 deletions thunder/executors/cudnn_layernormex.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,14 @@
import torch
import numpy as np

from lightning_utilities.core.imports import package_available

cudnn: None | Any = None
cudnn_backend_version: None | Any = None
if package_available("cudnn"):
import cudnn

cudnn_backend_version = cudnn.backend_version()


# WARNING: cudnn layernorm executor is experimental. Tests that use cudnn might fail.
from dataclasses import dataclass
from functools import lru_cache


from thunder.executors.cudnnex import torch_to_cudnn_dtype
from thunder.executors.cudnnex import cudnn_available, torch_to_cudnn_dtype
from thunder.extend import OperatorExecutor


@dataclass(frozen=True)
Expand All @@ -44,12 +36,6 @@ def wrapper(*args, **kwargs):
return wrapper


from thunder.extend import OperatorExecutor, register_executor

cudnn_layernorm_ex: OperatorExecutor = OperatorExecutor("cudnn_layernorm", version=cudnn_backend_version)
register_executor(cudnn_layernorm_ex)


@make_cacheable_cudnn_graph_inputs
@lru_cache(maxsize=1024)
def _make_cudnn_layer_norm_graph(a_4d, weight_4d, bias_4d):
Expand Down Expand Up @@ -133,7 +119,16 @@ def layer_norm_checker(a, normalized_shape, weight=None, bias=None, eps=1e-5):
return True


import thunder.torch as ltorch
cudnn_layernorm_ex: None | OperatorExecutor = None

if cudnn_available():
from thunder.extend import register_executor
import cudnn

cudnn_layernorm_ex: OperatorExecutor = OperatorExecutor("cudnn_layernorm", version=cudnn.backend_version())
register_executor(cudnn_layernorm_ex)

import thunder.torch as ltorch

layer_norm = cudnn_layernorm_ex.register_operator("cudnn_layernorm", like=ltorch.layer_norm, fn=layer_norm_impl)
cudnn_layernorm_ex.register_implementation(ltorch.layer_norm, layer_norm, checker=layer_norm_checker)
layer_norm = cudnn_layernorm_ex.register_operator("cudnn_layernorm", like=ltorch.layer_norm, fn=layer_norm_impl)
cudnn_layernorm_ex.register_implementation(ltorch.layer_norm, layer_norm, checker=layer_norm_checker)
Loading
Loading