[Kernels] add triton contiguous groupgemm #1154
Conversation
GPU test failure is not related to this PR.
kwen2501
left a comment
The integration code looks good to me.
Didn't look at the kernel closely. Did you check if the generated results are correct?
If correct, approving to unblock.
Thanks, and yes: results verified via unit tests (across a variety of shapes), and generate/inference produces valid output.
Triton v3.4.0 doesn't support fill_1d_tma_descriptor and fill_2d_tma_descriptor; it only supports fill_tma_descriptor. Do you have any plan to refactor the code to support TMA in newer Triton?
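One way to bridge the API change discussed here, sketched below: resolve whichever descriptor-fill entry point the installed Triton exposes at import time, falling back to the generic fill_tma_descriptor when the rank-specific names are gone. This is a minimal sketch, not the PR's code; in real use `utils` would be something like `triton.runtime.driver.active.utils` (an assumption), and it glosses over any signature differences between the old and new fill functions.

```python
from types import SimpleNamespace


def resolve_tma_fillers(utils):
    """Pick TMA descriptor fill functions from a Triton driver-utils object.

    Older Triton exposes fill_1d_tma_descriptor / fill_2d_tma_descriptor;
    Triton 3.4 exposes only fill_tma_descriptor. Returns a (fill_1d, fill_2d)
    pair, falling back to the generic entry point when the rank-specific
    names are absent.
    """
    generic = getattr(utils, "fill_tma_descriptor", None)
    fill_1d = getattr(utils, "fill_1d_tma_descriptor", generic)
    fill_2d = getattr(utils, "fill_2d_tma_descriptor", generic)
    if fill_1d is None or fill_2d is None:
        raise RuntimeError("no TMA descriptor fill API found in Triton driver utils")
    return fill_1d, fill_2d


# Demo with stand-in namespaces (real code would pass Triton's utils object):
old_api = SimpleNamespace(fill_1d_tma_descriptor=lambda *a: "1d",
                          fill_2d_tma_descriptor=lambda *a: "2d")
new_api = SimpleNamespace(fill_tma_descriptor=lambda *a: "generic")
```

Callers keep a single code path and only the resolution step changes per Triton version; the remaining work for 3.4 would be adapting the argument lists to the unified function's signature.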
Hi @yuan-luo - I'm not working in Triton anymore (hence I didn't realize they changed the descriptors), but I think @AdnanHoque can quickly make the update.
This PR adds a Triton-based contiguous group GEMM and integrates it with DeepSeek. This group GEMM has full backward support.

Testing:
a - verified the CG group GEMM works for DeepSeek inference.
b - added and ran unit tests for both forward and backward. (The perf test is against manual looping.)

~~~
===== Testing Backward Pass: DeepSeek Shapes =====
Testing shape: M=512, K=128, N=128, group_size=128
Testing shape: M=1,024, K=1,024, N=1,024, group_size=128
.
===== Testing Forward Pass: DeepSeek Shapes =====
Testing shape: M=4,096, K=1,024, N=1,024, group_size=128
Testing shape: M=4,096, K=4,096, N=4,096, group_size=128
Testing shape: M=2,048, K=4,096, N=7,168, group_size=128
Testing shape: M=2,048, K=7,168, N=2,048, group_size=128
.
===== Benchmarking Backward Pass with DeepSeek Shapes =====
Benchmarking shape: M=512, K=128, N=128, group_size=128
Benchmarking shape: M=1,024, K=1,024, N=1,024, group_size=128

DeepSeek Backward Performance Summary:
Shape                     CG-GEMM (ms)   PyTorch (ms)   Speedup   TFLOPS
--------------------------------------------------------------------------------
M=512, K=128, N=128              0.305          0.837      2.74      0.11
M=1024, K=1024, N=1024           0.304          1.293      4.25     14.12
.
===== Benchmarking Forward Pass with DeepSeek Shapes =====
Benchmarking shape: M=2,048, K=1,024, N=1,024, group_size=128
Benchmarking shape: M=4,096, K=4,096, N=4,096, group_size=128

DeepSeek Forward Performance Summary:
Shape                     CG-GEMM (ms)   PyTorch (ms)   Speedup   TFLOPS
--------------------------------------------------------------------------------
M=2048, K=1024, N=1024           0.050          0.617     12.46     86.69
M=4096, K=4096, N=4096           0.380          1.569      4.13    361.60
.
----------------------------------------------------------------------
Ran 4 tests in 6.268s

OK
~~~
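For readers unfamiliar with the operation: in a contiguous group GEMM, the activation rows are pre-sorted so each expert's tokens occupy one contiguous block, and each block is multiplied by that expert's weight matrix. The sketch below is a plain-Python reference of those semantics (essentially the "manual looping" baseline the PR benchmarks against), not the Triton kernel; names and list-of-lists layout are illustrative only.

```python
def contiguous_grouped_gemm_ref(x, w, group_sizes):
    """Reference semantics for a contiguous grouped GEMM.

    x:           M x K activations (list of rows), pre-sorted so expert e's
                 tokens occupy a contiguous block of group_sizes[e] rows.
    w:           list of per-expert K x N weight matrices.
    group_sizes: number of rows routed to each expert, in row order.
    Returns the M x N output computed by looping over groups.
    """
    out = []
    row = 0
    for e, m in enumerate(group_sizes):
        K = len(w[e])
        N = len(w[e][0])
        for i in range(row, row + m):
            # Plain matrix-vector product of row i against expert e's weights.
            out.append([sum(x[i][k] * w[e][k][n] for k in range(K))
                        for n in range(N)])
        row += m
    return out


# Tiny example: 3 tokens, 2 experts; first 2 rows go to expert 0.
x = [[1, 0], [0, 1], [1, 1]]
w = [[[1, 2], [3, 4]],   # expert 0: 2x2
     [[5, 6], [7, 8]]]   # expert 1: 2x2
y = contiguous_grouped_gemm_ref(x, w, group_sizes=[2, 1])
# y == [[1, 2], [3, 4], [12, 14]]
```

The Triton kernel in this PR fuses this loop into a single launch over tiles; the reference above is only useful for correctness checks on small shapes.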