diff --git a/cpp/tensorrt_llm/thop/cublasScaledMM.cpp b/cpp/tensorrt_llm/thop/cublasScaledMM.cpp
index d39b7b693fe..f4b2291e6c1 100644
--- a/cpp/tensorrt_llm/thop/cublasScaledMM.cpp
+++ b/cpp/tensorrt_llm/thop/cublasScaledMM.cpp
@@ -343,11 +343,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
     m.def(
         "cublas_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scale_a, Tensor scale_b, Tensor? bias,"
         " ScalarType? out_dtype, bool to_userbuffers=False) -> (Tensor out)");
-    m.def(
-        "cublas_scaled_mm_out(Tensor mat_a, Tensor mat_b, Tensor scale_a, Tensor scale_b, Tensor? bias,"
-        " int userbuffers_id, Tensor! out) -> (Tensor out)");
     m.def("cublas_mm(Tensor mat_a, Tensor mat_b, Tensor? bias, ScalarType? out_dtype) -> (Tensor out)");
-    m.def("cublas_mm_out(Tensor mat_a, Tensor mat_b, Tensor? bias, Tensor! out) -> (Tensor out)");
 }
 
 TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
diff --git a/tests/unittest/_torch/test_custom_ops.py b/tests/unittest/_torch/test_custom_ops.py
new file mode 100644
index 00000000000..9a50468ae30
--- /dev/null
+++ b/tests/unittest/_torch/test_custom_ops.py
@@ -0,0 +1,106 @@
+import pytest
+import torch
+import torch._library.utils as library_utils
+
+import tensorrt_llm  # noqa: F401
+
+
+def discover_namespace_ops(namespace: str, prefix: str = ""):
+    """Discover custom ops in a specific namespace."""
+    # C++ custom ops are lazily loaded, so torch.ops.<namespace> cannot enumerate them.
+    # Discover them through the registered schemas instead.
+    ops_schemas = torch._C._jit_get_all_schemas()
+    ops = []
+
+    ns_prefix = f"{namespace}::{prefix}"
+    print("Discovering custom ops:")
+    for schema in ops_schemas:
+        if not schema.name.startswith(ns_prefix):
+            continue
+        op = library_utils.lookup_op(schema.name)
+        ops.append(op)
+        print(f"  {op._name}")
+    return ops
+
+
+def discover_custom_ops(namespaces):
+    """Discover all custom ops in the given namespaces."""
+    discovered_ops = []
+    for ns in namespaces:
+        ops = discover_namespace_ops(ns)
+        print(f"Total {len(ops)} custom ops in namespace {ns}")
+        discovered_ops.extend(ops)
+    return discovered_ops
+
+
+@pytest.fixture(scope="module", autouse=True)
+def custom_ops():
+    """Discover custom ops in the codebase."""
+    # "auto_deploy" custom ops are not checked here.
+    custom_op_namespaces = ("trtllm", )
+
+    return discover_custom_ops(custom_op_namespaces)
+
+
+# It would be better to add an OpInfo for each custom op and use opcheck to test it.
+# OpInfos for custom ops are currently not available in the codebase.
+# As a trade-off, only the fake-impl registration is checked.
+def test_register_fake(custom_ops):
+    """Test custom operator fake impl registration."""
+
+    # Custom ops that are not required to have a fake impl.
+    waivers = {
+        "trtllm::record_stream",
+        "trtllm::wait_event",
+        "trtllm::record_event",
+        "trtllm::set_stream",
+    }
+
+    # TODO: add fake impls for these ops in follow-up PRs.
+    to_fix = {
+        "trtllm::lora_grouped_gemm",
+        "trtllm::mtp_relaxed_acceptance_op",
+        "trtllm::mtp_update_hidden_states_op",
+        "trtllm::mtp_prepare_drafter_inputs_op",
+        "trtllm::selective_scan",
+        "trtllm::reducescatter_list",
+        "trtllm::fp8_per_tensor_scale_moe_runner",
+        "trtllm::migrate_to_host_accessible",
+        "trtllm::mnnvl_moe_alltoallv_prepare_without_allgather",
+        "trtllm::mamba_conv1d",
+        "trtllm::llama4_moe_tp8ep1_min_latency",
+        "trtllm::llama4_fp8_fp8_gemm_swiglu",
+        "trtllm::llama4_fp8_bf16_gemm",
+        "trtllm::llama4_bf16_bf16_gemm",
+        "trtllm::fused_topk_softmax",
+        "trtllm::fp8_batched_quantize_1x128_permute102",
+        "trtllm::fp8_block_scaling_moe_gemm",
+        "trtllm::fp8_block_scaling_bmm_out",
+        "trtllm::fp8_block_scaling_bmm",
+        "trtllm::fp4_batched_quantize",
+        "trtllm::fp4_gemm_trtllmgen",
+        "trtllm::fp4_bmm",
+        "trtllm::merge_chunked_attention_for_mla",
+        "trtllm::cuda_scaled_mm",
+        "trtllm::initialize_static_lowprecision_buffers",
+        "trtllm::cutlass_scaled_mm",
+        "trtllm::fp8_per_tensor_scaling_tllmg_gemm",
+        "trtllm::load_chunked_kv_cache_for_mla",
+        "trtllm::load_paged_kv_cache_for_mla",
+        "trtllm::set_paged_kv_cache_for_mla",
+        "trtllm::set_chunked_kv_cache_for_mla",
+        "trtllm::mla_rope_append_paged_kv_assign_q",
+        "trtllm::fused_qk_norm_rope",
+    }
+
+    ops_missing_fake_impl = []
+
+    for op in custom_ops:
+        if op._name in waivers or op._name in to_fix:
+            continue
+        if not library_utils.has_fake_kernel(op):
+            ops_missing_fake_impl.append(op)
+
+    names = ", ".join(op._name for op in ops_missing_fake_impl)
+    assert len(ops_missing_fake_impl) == 0, \
+        f"Custom ops missing fake impl: {names}"
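
For context on what the test checks: a fake impl can be registered from Python with torch.library.register_fake. Below is a minimal sketch for trtllm::cublas_mm, whose schema appears in the C++ hunk above. The shape/dtype logic is an assumption for illustration only; the real registration must mirror the CUDA kernel, and the op may already have a fake impl in the codebase.

import torch

# Hedged sketch (not part of the patch): a fake/meta kernel for trtllm::cublas_mm,
# assuming a plain [m, k] x [k, n] -> [m, n] matmul where out_dtype falls back to
# mat_a's dtype when not given.
@torch.library.register_fake("trtllm::cublas_mm")
def _cublas_mm_fake(mat_a, mat_b, bias, out_dtype):
    dtype = out_dtype if out_dtype is not None else mat_a.dtype
    return mat_a.new_empty((mat_a.shape[0], mat_b.shape[1]), dtype=dtype)

With such a registration in place, torch._library.utils.has_fake_kernel returns True for the op, so an op fixed this way can be removed from the to_fix set.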