clean up ruff
Signed-off-by: Yuhan Li <51736452+liyuhannnnn@users.noreply.github.com>
liyuhannnnn committed Dec 1, 2025
commit e78d1f5d5b7e1cc1f6820fda97cdaffcf6504c6d
2 changes: 0 additions & 2 deletions .pre-commit-config.yaml
@@ -943,8 +943,6 @@ common-files: &common_files |
tests/scripts/iteration_log_parser.py |
tests/scripts/perf-sanity/parse_benchmark_results.py |
tests/scripts/perf-sanity/run_benchmark_serve.py |
- tests/scripts/cute_dsl_kernels/test_dense_blockscaled_gemm_persistent.py |
- tests/scripts/cute_dsl_kernels/testing.py |
tests/unittest/_torch/attention/sparse/test_dsa_indexer.py |
tests/unittest/_torch/attention/sparse/test_flash_mla.py |
tests/unittest/_torch/attention/sparse/test_rocketkv.py |
2 changes: 0 additions & 2 deletions pyproject.toml
@@ -983,8 +983,6 @@ exclude = [
"tests/scripts/iteration_log_parser.py",
"tests/scripts/perf-sanity/parse_benchmark_results.py",
"tests/scripts/perf-sanity/run_benchmark_serve.py",
"tests/scripts/cute_dsl_kernels/test_dense_blockscaled_gemm_persistent.py",
"tests/scripts/cute_dsl_kernels/testing.py",
"tests/unittest/_torch/attention/sparse/test_dsa_indexer.py",
"tests/unittest/_torch/attention/sparse/test_flash_mla.py",
"tests/unittest/_torch/attention/sparse/test_rocketkv.py",
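Both scripts were dropped from the ruff exclude list in pyproject.toml and from the matching file list in .pre-commit-config.yaml, so ruff now checks and formats them along with the rest of the tree. A minimal local verification, assuming ruff and pre-commit are installed in the development environment, could be:

ruff check tests/scripts/cute_dsl_kernels/test_dense_blockscaled_gemm_persistent.py tests/scripts/cute_dsl_kernels/testing.py
ruff format --check tests/scripts/cute_dsl_kernels/test_dense_blockscaled_gemm_persistent.py tests/scripts/cute_dsl_kernels/testing.py
pre-commit run --files tests/scripts/cute_dsl_kernels/test_dense_blockscaled_gemm_persistent.py tests/scripts/cute_dsl_kernels/testing.py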
tests/scripts/cute_dsl_kernels/test_dense_blockscaled_gemm_persistent.py
@@ -55,12 +55,11 @@
from cutlass.cute.runtime import from_dlpack

try:
- from tensorrt_llm._torch.cute_dsl_kernels.blackwell import \
- dense_blockscaled_gemm_persistent as kernel_module
+ from tensorrt_llm._torch.cute_dsl_kernels.blackwell import (
+ dense_blockscaled_gemm_persistent as kernel_module,
+ )
except (ModuleNotFoundError, ImportError):
- sys.path.insert(
- 0,
- str(Path(__file__).parents[3] / "tensorrt_llm/_torch/cute_dsl_kernels"))
+ sys.path.insert(0, str(Path(__file__).parents[3] / "tensorrt_llm/_torch/cute_dsl_kernels"))
from blackwell import dense_blockscaled_gemm_persistent as kernel_module

Sm100BlockScaledPersistentDenseGemmKernel = kernel_module.Sm100BlockScaledPersistentDenseGemmKernel
@@ -130,16 +129,12 @@ def run(
RuntimeError: If no CUDA-capable GPU is available.
TypeError: If the configuration is not supported by the kernel.
"""
print(f"Running Sm100 Persistent Dense BlockScaled GEMM test with:")
print("Running Sm100 Persistent Dense BlockScaled GEMM test with:")
print(f"mnkl: {mnkl}")
- print(
- f"AB dtype: {ab_dtype}, SF dtype: {sf_dtype}, SF Vec size: {sf_vec_size}"
- )
+ print(f"AB dtype: {ab_dtype}, SF dtype: {sf_dtype}, SF Vec size: {sf_vec_size}")
print(f"C dtype: {c_dtype}")
print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}")
- print(
- f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}"
- )
+ print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}")
print(f"Use prefetch: {'True' if use_prefetch else 'False'}")
print(f"Tolerance: {tolerance}")
print(f"Warmup iterations: {warmup_iterations}")
@@ -149,26 +144,28 @@
print(f"Use CUPTI: {'True' if use_cupti else 'False'}")

# Unpack parameters
- m, n, k, l = mnkl
+ m, n, k, batch = mnkl

# Skip unsupported testcase
if not Sm100BlockScaledPersistentDenseGemmKernel.can_implement(
- ab_dtype,
- sf_dtype,
- sf_vec_size,
- c_dtype,
- mma_tiler_mn,
- cluster_shape_mn,
- m,
- n,
- k,
- l,
- a_major,
- b_major,
- c_major,
+ ab_dtype,
+ sf_dtype,
+ sf_vec_size,
+ c_dtype,
+ mma_tiler_mn,
+ cluster_shape_mn,
+ m,
+ n,
+ k,
+ batch,
+ a_major,
+ b_major,
+ c_major,
):
raise TypeError(
f"Unsupported testcase {ab_dtype}, {sf_dtype}, {sf_vec_size}, {c_dtype}, {mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}"
f"Unsupported testcase {ab_dtype}, {sf_dtype}, {sf_vec_size}, {c_dtype}, "
f"{mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {batch}, "
f"{a_major}, {b_major}, {c_major}"
)

if not torch.cuda.is_available():
@@ -177,22 +174,19 @@
torch.manual_seed(1111)

# Create tensor A/B/C
- a_ref = cutlass_torch.matrix(l, m, k, a_major == "m", cutlass.Float32)
- b_ref = cutlass_torch.matrix(l, n, k, b_major == "n", cutlass.Float32)
- c_ref = cutlass_torch.matrix(l, m, n, c_major == "m", cutlass.Float32)
-
- a_tensor, a_torch = cutlass_torch.cute_tensor_like(a_ref,
- ab_dtype,
- is_dynamic_layout=True,
- assumed_align=16)
- b_tensor, b_torch = cutlass_torch.cute_tensor_like(b_ref,
- ab_dtype,
- is_dynamic_layout=True,
- assumed_align=16)
- c_tensor, c_torch = cutlass_torch.cute_tensor_like(c_ref,
- c_dtype,
- is_dynamic_layout=True,
- assumed_align=16)
+ a_ref = cutlass_torch.matrix(batch, m, k, a_major == "m", cutlass.Float32)
+ b_ref = cutlass_torch.matrix(batch, n, k, b_major == "n", cutlass.Float32)
+ c_ref = cutlass_torch.matrix(batch, m, n, c_major == "m", cutlass.Float32)
+
+ a_tensor, a_torch = cutlass_torch.cute_tensor_like(
+ a_ref, ab_dtype, is_dynamic_layout=True, assumed_align=16
+ )
+ b_tensor, b_torch = cutlass_torch.cute_tensor_like(
+ b_ref, ab_dtype, is_dynamic_layout=True, assumed_align=16
+ )
+ c_tensor, c_torch = cutlass_torch.cute_tensor_like(
+ c_ref, c_dtype, is_dynamic_layout=True, assumed_align=16
+ )

# Mark tensor to be byte aligned
a_tensor.mark_compact_shape_dynamic(
@@ -212,18 +206,17 @@
)

# Create scale factor tensor SFA/SFB
- def create_scale_factor_tensor(l, mn, k, sf_vec_size, dtype):
-
+ def create_scale_factor_tensor(batch, mn, k, sf_vec_size, dtype):
def ceil_div(a, b):
return (a + b - 1) // b

sf_k = ceil_div(k, sf_vec_size)
- ref_shape = (l, mn, sf_k)
+ ref_shape = (batch, mn, sf_k)

atom_m = (32, 4)
atom_k = 4
mma_shape = (
- l,
+ batch,
ceil_div(mn, atom_m[0] * atom_m[1]),
ceil_div(sf_k, atom_k),
atom_m[0],
@@ -266,9 +259,13 @@ def ceil_div(a, b):
cute_f32_torch_tensor = cute_f32_torch_tensor_cpu.cuda()

# reshape makes memory contiguous
- ref_f32_torch_tensor_cpu = (ref_f32_torch_tensor_cpu.permute(
- 2, 0, 1).unsqueeze(-1).expand(l, mn, sf_k, sf_vec_size).reshape(
- l, mn, sf_k * sf_vec_size).permute(*ref_permute_order))
+ ref_f32_torch_tensor_cpu = (
+ ref_f32_torch_tensor_cpu.permute(2, 0, 1)
+ .unsqueeze(-1)
+ .expand(batch, mn, sf_k, sf_vec_size)
+ .reshape(batch, mn, sf_k * sf_vec_size)
+ .permute(*ref_permute_order)
+ )
# prune to mkl for reference check.
ref_f32_torch_tensor_cpu = ref_f32_torch_tensor_cpu[:, :k, :]

@@ -289,10 +286,8 @@
)
return ref_f32_torch_tensor_cpu, cute_tensor, cute_torch_tensor

- sfa_ref, sfa_tensor, sfa_torch = create_scale_factor_tensor(
- l, m, k, sf_vec_size, sf_dtype)
- sfb_ref, sfb_tensor, sfb_torch = create_scale_factor_tensor(
- l, n, k, sf_vec_size, sf_dtype)
+ sfa_ref, sfa_tensor, sfa_torch = create_scale_factor_tensor(batch, m, k, sf_vec_size, sf_dtype)
+ sfb_ref, sfb_tensor, sfb_torch = create_scale_factor_tensor(batch, n, k, sf_vec_size, sf_dtype)

# Configure gemm kernel
gemm = Sm100BlockScaledPersistentDenseGemmKernel(
@@ -305,7 +300,8 @@ def ceil_div(a, b):
# Compute max active clusters on current device
hardware_info = cutlass.utils.HardwareInfo()
max_active_clusters = hardware_info.get_max_active_clusters(
- cluster_shape_mn[0] * cluster_shape_mn[1])
+ cluster_shape_mn[0] * cluster_shape_mn[1]
+ )

# Initialize Stream
current_stream = cutlass_torch.default_stream()
@@ -321,14 +317,13 @@
1.0, # alpha
max_active_clusters,
current_stream,
options=f"--opt-level 2",
options="--opt-level 2",
)

# Compute reference result
if not skip_ref_check:
# Execute kernel once for reference checking
- compiled_gemm(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, 1.0,
- current_stream)
+ compiled_gemm(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, 1.0, current_stream)
print("Verifying results...")
res_a = torch.einsum("mkl,mkl->mkl", a_ref, sfa_ref)
res_b = torch.einsum("nkl,nkl->nkl", b_ref, sfb_ref)
@@ -339,41 +334,37 @@
cute.testing.convert(
c_tensor,
from_dlpack(c_ref_device, assumed_align=16).mark_layout_dynamic(
- leading_dim=(1 if c_major == "n" else 0)),
+ leading_dim=(1 if c_major == "n" else 0)
+ ),
)
c_ref = c_ref_device.cpu()

if c_dtype in (cutlass.Float32, cutlass.Float16, cutlass.BFloat16):
torch.testing.assert_close(c_ref, ref, atol=tolerance, rtol=1e-02)
elif c_dtype in (cutlass.Float8E5M2, cutlass.Float8E4M3FN):
# Convert ref : f32 -> f8 -> f32
- ref_f8_ = torch.empty(*(l, m, n), dtype=torch.uint8,
- device="cuda").permute(1, 2, 0)
- ref_f8 = from_dlpack(
- ref_f8_, assumed_align=16).mark_layout_dynamic(leading_dim=1)
+ ref_f8_ = torch.empty(*(batch, m, n), dtype=torch.uint8, device="cuda").permute(1, 2, 0)
+ ref_f8 = from_dlpack(ref_f8_, assumed_align=16).mark_layout_dynamic(leading_dim=1)
ref_f8.element_type = c_dtype
- ref_device = ref.permute(2, 0, 1).contiguous().permute(1, 2,
- 0).cuda()
- ref_tensor = from_dlpack(
- ref_device, assumed_align=16).mark_layout_dynamic(leading_dim=1)
+ ref_device = ref.permute(2, 0, 1).contiguous().permute(1, 2, 0).cuda()
+ ref_tensor = from_dlpack(ref_device, assumed_align=16).mark_layout_dynamic(
+ leading_dim=1
+ )
cute.testing.convert(ref_tensor, ref_f8)
cute.testing.convert(ref_f8, ref_tensor)
ref = ref_device.cpu()
torch.testing.assert_close(c_ref, ref, atol=tolerance, rtol=1e-02)

def generate_tensors():
- a_tensor, _ = cutlass_torch.cute_tensor_like(a_ref,
- ab_dtype,
- is_dynamic_layout=True,
- assumed_align=16)
- b_tensor, _ = cutlass_torch.cute_tensor_like(b_ref,
- ab_dtype,
- is_dynamic_layout=True,
- assumed_align=16)
- c_tensor, _ = cutlass_torch.cute_tensor_like(c_ref,
- c_dtype,
- is_dynamic_layout=True,
- assumed_align=16)
+ a_tensor, _ = cutlass_torch.cute_tensor_like(
+ a_ref, ab_dtype, is_dynamic_layout=True, assumed_align=16
+ )
+ b_tensor, _ = cutlass_torch.cute_tensor_like(
+ b_ref, ab_dtype, is_dynamic_layout=True, assumed_align=16
+ )
+ c_tensor, _ = cutlass_torch.cute_tensor_like(
+ c_ref, c_dtype, is_dynamic_layout=True, assumed_align=16
+ )

# Mark tensor to be byte aligned
a_tensor.mark_compact_shape_dynamic(
@@ -392,23 +383,24 @@ def generate_tensors():
divisibility=2 if c_dtype == cutlass.Float4E2M1FN else 1,
)

- _, sfa_tensor, _ = create_scale_factor_tensor(l, m, k, sf_vec_size,
- sf_dtype)
- _, sfb_tensor, _ = create_scale_factor_tensor(l, n, k, sf_vec_size,
- sf_dtype)
- return cute.testing.JitArguments(a_tensor, b_tensor, sfa_tensor,
- sfb_tensor, c_tensor, 1.0,
- current_stream)
+ _, sfa_tensor, _ = create_scale_factor_tensor(batch, m, k, sf_vec_size, sf_dtype)
+ _, sfb_tensor, _ = create_scale_factor_tensor(batch, n, k, sf_vec_size, sf_dtype)
+ return cute.testing.JitArguments(
+ a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, 1.0, current_stream
+ )

workspace_count = 1
if use_cold_l2:
- one_workspace_bytes = (a_torch.numel() * a_torch.element_size() +
- b_torch.numel() * b_torch.element_size() +
- sfa_torch.numel() * sfa_torch.element_size() +
- sfb_torch.numel() * sfb_torch.element_size() +
- c_torch.numel() * c_torch.element_size())
+ one_workspace_bytes = (
+ a_torch.numel() * a_torch.element_size()
+ + b_torch.numel() * b_torch.element_size()
+ + sfa_torch.numel() * sfa_torch.element_size()
+ + sfb_torch.numel() * sfb_torch.element_size()
+ + c_torch.numel() * c_torch.element_size()
+ )
workspace_count = cute.testing.get_workspace_count(
- one_workspace_bytes, warmup_iterations, iterations)
+ one_workspace_bytes, warmup_iterations, iterations
+ )

exec_time = benchmark(
compiled_gemm,
@@ -429,12 +421,10 @@ def parse_comma_separated_ints(s: str) -> Tuple[int, ...]:
try:
return tuple(int(x.strip()) for x in s.split(","))
except ValueError:
- raise argparse.ArgumentTypeError(
- "Invalid format. Expected comma-separated integers.")
+ raise argparse.ArgumentTypeError("Invalid format. Expected comma-separated integers.")

parser = argparse.ArgumentParser(
- description=
- "Functionality and Performance Test for Sm100 Dense Persistent BlockScaled GEMM."
+ description="Functionality and Performance Test for Sm100 Dense Persistent BlockScaled GEMM."
)

parser.add_argument(
@@ -455,16 +445,10 @@
default=(1, 4),
help="Cluster shape (comma-separated)",
)
parser.add_argument("--ab_dtype",
type=cutlass.dtype,
default=cutlass.Float4E2M1FN)
parser.add_argument("--sf_dtype",
type=cutlass.dtype,
default=cutlass.Float8E8M0FNU)
parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.Float4E2M1FN)
parser.add_argument("--sf_dtype", type=cutlass.dtype, default=cutlass.Float8E8M0FNU)
parser.add_argument("--sf_vec_size", type=int, default=16)
parser.add_argument("--c_dtype",
type=cutlass.dtype,
default=cutlass.Float16)
parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.Float16)
parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k")
parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k")
parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n")
@@ -474,23 +458,15 @@
default=False,
help="Enable TMA prefetch for both A and B matrices (default: False)",
)
parser.add_argument("--tolerance",
type=float,
default=1e-01,
help="Tolerance for validation")
parser.add_argument("--warmup_iterations",
type=int,
default=0,
help="Warmup iterations")
parser.add_argument("--tolerance", type=float, default=1e-01, help="Tolerance for validation")
parser.add_argument("--warmup_iterations", type=int, default=0, help="Warmup iterations")
parser.add_argument(
"--iterations",
type=int,
default=1,
help="Number of iterations to run the kernel",
)
parser.add_argument("--skip_ref_check",
action="store_true",
help="Skip reference checking")
parser.add_argument("--skip_ref_check", action="store_true", help="Skip reference checking")
parser.add_argument(
"--use_cold_l2",
action="store_true",
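The formatting and renaming changes above (including l renamed to batch) do not alter the script's behavior, so its command-line interface is unchanged. A hypothetical smoke-run, assuming the usual if __name__ == "__main__" entry point (not shown in this hunk) and using only flags visible in this diff while leaving the collapsed arguments such as the problem shape at their defaults, might look like:

python tests/scripts/cute_dsl_kernels/test_dense_blockscaled_gemm_persistent.py --warmup_iterations 5 --iterations 10 --tolerance 1e-01

As documented in the run() docstring, the test raises RuntimeError when no CUDA-capable GPU is available and TypeError for configurations the kernel does not support.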