[Kernels] add triton contiguous groupgemm #1154

Merged
lessw2020 merged 7 commits into main from lessw2020/add_contiguous_groupgemm on May 1, 2025
Conversation

@lessw2020 (Contributor) commented Apr 29, 2025

This PR adds a Triton-based contiguous grouped GEMM and integrates it with DeepSeek.
The grouped GEMM has full backward support.

Testing:
a - Verified the contiguous grouped (CG) GEMM works for DeepSeek inference.
b - Added and ran unit tests for both forward and backward passes. (The perf test compares against manual looping.)
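For readers unfamiliar with the layout, the "manual looping" baseline can be sketched in NumPy: in a contiguous grouped GEMM, rows are pre-sorted so each contiguous block of `group_size` rows is routed to a single expert, and each block gets one plain GEMM against that expert's weights. This is a sketch of the baseline only; the names (`grouped_gemm_reference`, `expert_indices`) are illustrative, not taken from the PR's kernel code.

```python
import numpy as np

def grouped_gemm_reference(x, w, expert_indices, group_size=128):
    """Reference contiguous grouped GEMM via manual looping.

    x: (M, K) activations, rows pre-sorted so each contiguous block of
       `group_size` rows is routed to a single expert.
    w: (num_experts, K, N) expert weight matrices.
    expert_indices: (M // group_size,) expert id per row-group.
    """
    M, _ = x.shape
    N = w.shape[2]
    out = np.empty((M, N), dtype=x.dtype)
    for g, e in enumerate(expert_indices):
        rows = slice(g * group_size, (g + 1) * group_size)
        out[rows] = x[rows] @ w[e]  # one ordinary GEMM per row-group
    return out
```

The Triton kernel fuses this loop into a single launch; the reference above is what the perf numbers below are measured against.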

===== Testing Backward Pass: DeepSeek Shapes   =====
Testing shape: M=512, K=128, N=128, group_size=128
Testing shape: M=1,024, K=1,024, N=1,024, group_size=128
.
===== Testing Forward Pass: DeepSeek Shapes =====
Testing shape: M=4,096, K=1,024, N=1,024, group_size=128
Testing shape: M=4,096, K=4,096, N=4,096, group_size=128
Testing shape: M=2,048, K=4,096, N=7,168, group_size=128
Testing shape: M=2,048, K=7,168, N=2,048, group_size=128
.
===== Benchmarking Backward Pass with DeepSeek Shapes   =====
Benchmarking shape: M=512, K=128, N=128, group_size=128
Benchmarking shape: M=1,024, K=1,024, N=1,024, group_size=128

DeepSeek Backward Performance Summary:
Shape                          CG-GEMM (ms)    PyTorch (ms)    Speedup    TFLOPS    
--------------------------------------------------------------------------------
M=512, K=128, N=128            0.305           0.837           2.74       0.11      
M=1024, K=1024, N=1024         0.304           1.293           4.25       14.12     
.
===== Benchmarking Forward Pass with DeepSeek Shapes =====
Benchmarking shape: M=2,048, K=1,024, N=1,024, group_size=128
Benchmarking shape: M=4,096, K=4,096, N=4,096, group_size=128

DeepSeek Forward Performance Summary:
Shape                          CG-GEMM (ms)    PyTorch (ms)    Speedup    TFLOPS    
--------------------------------------------------------------------------------
M=2048, K=1024, N=1024         0.050           0.617           12.46      86.69     
M=4096, K=4096, N=4096         0.380           1.569           4.13       361.60    
.
----------------------------------------------------------------------
Ran 4 tests in 6.268s

OK
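The TFLOPS columns above can be reproduced from the shapes and times using the standard 2·M·K·N FLOP count per GEMM; the backward numbers line up if you count two GEMMs (dX = dY·Wᵀ and dW = Xᵀ·dY). The helper name below is mine, not from the PR; small deviations from the table come from rounding of the measured times.

```python
def gemm_tflops(M, K, N, ms, num_gemms=1):
    """Achieved TFLOPS for `num_gemms` GEMMs of shape (M, K) @ (K, N).

    Each GEMM costs 2*M*K*N FLOPs; `ms` is the measured time in
    milliseconds. Forward is one GEMM; backward is two (dX and dW).
    """
    flops = 2 * M * K * N * num_gemms
    return flops / (ms * 1e-3) / 1e12

# Forward, one GEMM: close to the 361.60 TFLOPS table entry.
print(round(gemm_tflops(4096, 4096, 4096, 0.380), 1))
# Backward, two GEMMs: close to the 14.12 TFLOPS table entry.
print(round(gemm_tflops(1024, 1024, 1024, 0.304, num_gemms=2), 2))
```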

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Apr 29, 2025
@lessw2020 lessw2020 requested a review from kwen2501 April 29, 2025 17:59
@lessw2020 lessw2020 changed the title [WIP][Kernels] add triton contiguous groupgemm [Kernels] add triton contiguous groupgemm Apr 29, 2025
@lessw2020 (Contributor, Author)

The GPU test failure is not related to this PR:
[rank0]: RuntimeError: CUDA driver error: invalid device context

@kwen2501 (Contributor) left a comment:

The integration code looks good to me.
Didn't look at the kernel closely. Did you check if the generated results are correct?
If correct, approving to unblock.

@lessw2020 (Contributor, Author)

> The integration code looks good to me. Didn't look at the kernel closely. Did you check if the generated results are correct? If correct, approving to unblock.

Thanks, and yes: results were verified via unit tests (across a variety of shapes), as well as by generate/inference producing valid output.

@lessw2020 lessw2020 merged commit a32f0df into main May 1, 2025
5 of 6 checks passed
@lessw2020 lessw2020 deleted the lessw2020/add_contiguous_groupgemm branch May 1, 2025 21:45
@yuan-luo

Triton v3.4.0 doesn't support fill_1d_tma_descriptor and fill_2d_tma_descriptor; it only supports fill_tma_descriptor. Do you have any plan to refactor the code to support TMA in newer Triton?

/opt/conda/lib/python3.10/site-packages/sglang/srt/layers/moe/fused_moe_ibm/persistent_gg_bf16_tma.py:291: in grouped_gemm_persistent
    return _grouped_gemm_persistent(x, w, expert_indices)
/opt/conda/lib/python3.10/site-packages/sglang/srt/layers/moe/fused_moe_ibm/persistent_gg_bf16_tma.py:196: in _grouped_gemm_persistent
    desc_helper = TmaDescriptorHelper(tma_size=tma_size)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <sglang.srt.layers.moe.fused_moe_ibm.tma_autotune.TmaDescriptorHelper object at 0x7f1975b901c0>, tma_size = 128

    def __init__(self, tma_size: int = 128):
        """Initialize the TMA descriptor helper.
    
        Args:
            tma_size: Size of the TMA descriptor in bytes
        """
        if not CudaUtils.verify_tma():
            raise RuntimeError(
                "TMA not supported on this device (requires Hopper or newer)"
            )
    
        self.tma_size = tma_size
>       self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_1d_tma_descriptor
E       AttributeError: 'CudaUtils' object has no attribute 'fill_1d_tma_descriptor'. Did you mean: 'fill_tma_descriptor'?
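Until the kernel is updated, one way to run on both old and new Triton builds is a small attribute-dispatch shim. This is a sketch, not code from the repo: it only resolves which fill function exists on `driver.active.utils` and leaves argument adaptation to the caller, since the unified `fill_tma_descriptor` takes different arguments than the legacy per-rank entry points.

```python
def resolve_tma_fill(utils, rank):
    """Resolve a TMA descriptor fill function across Triton versions.

    `utils` is expected to be Triton's driver.active.utils object. Older
    builds expose fill_1d_tma_descriptor / fill_2d_tma_descriptor; newer
    builds (per the traceback above) expose a single fill_tma_descriptor.
    NOTE: the two APIs differ in signature, so the caller must branch on
    the returned kind when actually filling the descriptor.
    """
    legacy = getattr(utils, f"fill_{rank}d_tma_descriptor", None)
    if legacy is not None:
        return legacy, "legacy"
    unified = getattr(utils, "fill_tma_descriptor", None)
    if unified is not None:
        return unified, "unified"
    raise RuntimeError("no TMA descriptor fill API found on this Triton build")
```

With this, `TmaDescriptorHelper.__init__` could store the resolved callable plus its kind instead of assuming `fill_1d_tma_descriptor` exists.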

@lessw2020 (Contributor, Author)

Hi @yuan-luo - I'm not working in Triton anymore (hence didn't realize they changed the descriptors) but I think @AdnanHoque can quickly make the update.

xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 13, 2026
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 25, 2026