implement group gemm for contiguous case #136
base: main
Conversation
Diff excerpt under discussion:

```cpp
if (args.common.is_tma_multicast_valid) {
    // Presumably only the leader CTA (cluster rank 0) registers the expected
    // TMA byte count for the multicast B load on the cluster barrier.
    if (cluster_ctarank() == 0) {
        tma::cluster::expect(args.inputs_cluster_arrived, 0, args.input.b);
```
@benjaminfspector @DanFu09 @simran-arora Currently, the non-multicast implementation reaches about 80% of DeepGEMM's performance, but there is a performance regression when using TMA multicast. I tested DeepGEMM's TMA multicast path, which achieves a 10%-20% performance improvement. There may be an issue in my code. Could you help review it? Thanks!
@DanFu09 @benjaminfspector @simran-arora @StuartSul Any suggestions? Thanks!
See this: https://github.com/HazyResearch/ThunderKittens/pull/98/files. You probably want to use larger wgmma ops.
Yes, I tried it; it gets better performance with a larger N block size (DeepGEMM uses a block size of 192), but I want to use the TMA multicast feature to reduce loads from global memory to shared memory. I tested the DeepGEMM code, and its TMA multicast path gets a 10%-20% performance improvement. Thanks!
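For intuition on why multicast should help here, a rough traffic model (a sketch, not a measurement; the 2-CTA cluster along M, the choice of B as the multicast operand, and the FP8 element size are assumptions):

```python
# Rough model of global-to-shared traffic for the B operand, with and
# without TMA multicast across a CTA cluster (hypothetical 2-CTA cluster
# along M; FP8 inputs assumed, 1 byte per element).
def b_traffic_bytes(m, n, k, block_m=128, block_n=128, block_k=128,
                    cluster_size=1):
    tiles_m, tiles_n, tiles_k = m // block_m, n // block_n, k // block_k
    # Without multicast, every CTA loads its own copy of each B tile;
    # with multicast, CTAs in a cluster share one load per B tile.
    loads = tiles_m * tiles_n * tiles_k // cluster_size
    return loads * block_n * block_k  # bytes, at 1 byte per FP8 element

m, n, k = 32768, 4096, 7168  # num_groups=4, expected_m_per_group=8192
print(b_traffic_bytes(m, n, k) / 1e9)                  # ~7.5 GB
print(b_traffic_bytes(m, n, k, cluster_size=2) / 1e9)  # ~3.8 GB, halved
```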
Force-pushed from 84040e3 to 0683cc0.
Current performance is:

```
Testing grouped contiguous GEMM for deepgemm(block_m=block_k=128,block_n=192):
> Perf (num_groups= 4, expected_m_per_group=8192, n=4096, k=7168): 1408 us | throughput: 1367 TFLOPS, 441 GB/s
> Perf (num_groups= 4, expected_m_per_group=8192, n=7168, k=2048): 749 us | throughput: 1285 TFLOPS, 796 GB/s
> Perf (num_groups= 8, expected_m_per_group=4096, n=4096, k=7168): 1411 us | throughput: 1364 TFLOPS, 523 GB/s
> Perf (num_groups= 8, expected_m_per_group=4096, n=7168, k=2048): 752 us | throughput: 1279 TFLOPS, 870 GB/s
> Perf (num_groups=32, expected_m_per_group= 256, n=4096, k=7168): 465 us | throughput: 1036 TFLOPS, 2294 GB/s
> Perf (num_groups=32, expected_m_per_group= 256, n=7168, k=2048): 244 us | throughput: 985 TFLOPS, 2475 GB/s
Testing grouped contiguous GEMM for deepgemm(block_m=block_n=block_k=128):
> Perf (num_groups= 4, expected_m_per_group=8192, n=4096, k=7168): 1505 us | throughput: 1279 TFLOPS, 412 GB/s
> Perf (num_groups= 4, expected_m_per_group=8192, n=7168, k=2048): 796 us | throughput: 1209 TFLOPS, 748 GB/s
> Perf (num_groups= 8, expected_m_per_group=4096, n=4096, k=7168): 1504 us | throughput: 1279 TFLOPS, 491 GB/s
> Perf (num_groups= 8, expected_m_per_group=4096, n=7168, k=2048): 798 us | throughput: 1206 TFLOPS, 820 GB/s
> Perf (num_groups=32, expected_m_per_group= 256, n=4096, k=7168): 482 us | throughput: 998 TFLOPS, 2211 GB/s
> Perf (num_groups=32, expected_m_per_group= 256, n=7168, k=2048): 256 us | throughput: 939 TFLOPS, 2358 GB/s
Testing grouped contiguous GEMM for tk(block_m=block_n=block_k=128):
> Perf (num_groups= 4, expected_m_per_group=8192, n=4096, k=7168): 2150 us | throughput: 895 TFLOPS, 289 GB/s
> Perf (num_groups= 4, expected_m_per_group=8192, n=7168, k=2048): 1137 us | throughput: 846 TFLOPS, 524 GB/s
> Perf (num_groups= 8, expected_m_per_group=4096, n=4096, k=7168): 2148 us | throughput: 896 TFLOPS, 344 GB/s
> Perf (num_groups= 8, expected_m_per_group=4096, n=7168, k=2048): 1136 us | throughput: 847 TFLOPS, 576 GB/s
> Perf (num_groups=32, expected_m_per_group= 256, n=4096, k=7168): 595 us | throughput: 808 TFLOPS, 1790 GB/s
> Perf (num_groups=32, expected_m_per_group= 256, n=7168, k=2048): 315 us | throughput: 763 TFLOPS, 1916 GB/s
```

Note that the current TK implementation with multicast shows roughly a 15% performance regression compared to the non-multicast path.
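The reported TFLOPS and GB/s figures can be reproduced from the shapes and timings above. A sketch of the arithmetic (assuming FP8 A and B, a BF16 output D, and one B matrix per group; these assumptions match the printed numbers):

```python
# Recompute throughput from problem shape and measured time.
def gemm_stats(num_groups, m_per_group, n, k, t_us):
    m = num_groups * m_per_group
    flops = 2 * m * n * k                                 # 2 FLOPs per MAC
    bytes_moved = m * k + num_groups * k * n + 2 * m * n  # FP8 A + FP8 B + BF16 D
    t = t_us * 1e-6
    return flops / t / 1e12, bytes_moved / t / 1e9

tflops, gbps = gemm_stats(4, 8192, 4096, 7168, 1408)
print(f"{tflops:.0f} TFLOPS, {gbps:.0f} GB/s")  # ~1367 TFLOPS, ~441 GB/s
```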
This pull request implements FP8 group GEMM for the contiguous case. The masked case will be added in a follow-up.
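For reference, the semantics of the contiguous case can be sketched in a few lines (the array names and the m_indices row-to-group mapping below are assumptions modeled on DeepGEMM's contiguous grouped layout, not this PR's actual API):

```python
import numpy as np

def grouped_gemm_contiguous_ref(a, b, m_indices):
    # a:         (sum_m, k)         -- all groups' rows concatenated along M
    # b:         (num_groups, k, n) -- one B matrix per group
    # m_indices: (sum_m,)           -- group id owning each row of `a`
    sum_m, _ = a.shape
    num_groups, _, n = b.shape
    d = np.zeros((sum_m, n), dtype=np.float32)
    for g in range(num_groups):
        rows = m_indices == g      # a contiguous run of rows for group g
        d[rows] = a[rows] @ b[g]   # each group's rows multiply its own B
    return d
```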