Conversation

XiaobingSuper commented Jul 22, 2025

This pull request implements FP8 grouped GEMM for the contiguous case; the masked case will be added in a follow-up.

@XiaobingSuper XiaobingSuper marked this pull request as draft July 22, 2025 07:26

if (args.common.is_tma_multicast_valid) {
    if (cluster_ctarank() == 0) {
        // the first CTA in the cluster posts the expected arrival bytes for the
        // multicast B tile on the cluster barrier
        tma::cluster::expect(args.inputs_cluster_arrived, 0, args.input.b);
XiaobingSuper commented on this diff, Jul 22, 2025

@benjaminfspector @DanFu09 @simran-arora Currently, the non-multicast implementation reaches about 80% of deepgemm's performance, but there is a performance regression when TMA multicast is enabled. I tested the deepgemm TMA multicast path, which achieves a 10%-20% performance improvement, so there may be an issue in my code. Could you help review it? Thanks!

DanFu09 commented Jul 23, 2025

See this: https://github.com/HazyResearch/ThunderKittens/pull/98/files

You probably want to use larger wgmma ops.

XiaobingSuper replied:

> See this: https://github.com/HazyResearch/ThunderKittens/pull/98/files
>
> You probably want to use larger wgmma ops.

Yes, I tried that. Performance is better with a larger N block size (deepgemm uses a block size of 192), but I also want to use the TMA multicast feature to reduce the traffic from global memory to shared memory. When I tested the deepgemm code, its TMA multicast path gives a 10%~20% performance improvement. Thanks!
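
For intuition on both points, here is a rough back-of-the-envelope model (my own sketch, not part of the PR; the 128/192 tile sizes come from the discussion above, and I assume B is the operand shared across a 2-CTA cluster, as the cluster-expect on args.input.b suggests):

```cpp
#include <cstdio>

static double intensity(double bm, double bn, double bk) {
    // FLOPs per byte of A and B loaded for one bm x bn output tile, FP8 (1 B) inputs.
    return 2.0 * bm * bn * bk / (bm * bk + bn * bk);
}

int main() {
    printf("arithmetic intensity, 128x128 tile: %.1f FLOP/B\n", intensity(128, 128, 128));
    printf("arithmetic intensity, 128x192 tile: %.1f FLOP/B\n", intensity(128, 192, 128));

    // Global bytes per K-step for one 2-CTA cluster, with and without multicasting B.
    const double a_tile = 128.0 * 128.0, b_tile = 128.0 * 128.0, cluster = 2.0;
    const double no_mcast = cluster * (a_tile + b_tile);   // each CTA fetches its own B copy
    const double mcast    = cluster * a_tile + b_tile;     // B fetched once, broadcast to both CTAs
    printf("per-cluster global bytes: %.0f vs %.0f with B multicast (%.0f%% less)\n",
           no_mcast, mcast, 100.0 * (no_mcast - mcast) / no_mcast);
    return 0;
}
```

Under these assumptions the 192-wide tile raises arithmetic intensity by about 20%, and multicasting B cuts the modeled per-cluster global traffic by about 25%; actual gains depend on L2 behavior and the scaling-factor loads, which this model ignores.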

@XiaobingSuper XiaobingSuper marked this pull request as ready for review July 23, 2025 07:22
XiaobingSuper commented Jul 24, 2025

Current performance is:

Testing grouped contiguous GEMM for deepgemm(block_m=block_k=128,block_n=192):
 > Perf (num_groups= 4, expected_m_per_group=8192, n=4096, k=7168): 1408 us | throughput: 1367 TFLOPS,  441 GB/s
 > Perf (num_groups= 4, expected_m_per_group=8192, n=7168, k=2048):  749 us | throughput: 1285 TFLOPS,  796 GB/s
 > Perf (num_groups= 8, expected_m_per_group=4096, n=4096, k=7168): 1411 us | throughput: 1364 TFLOPS,  523 GB/s
 > Perf (num_groups= 8, expected_m_per_group=4096, n=7168, k=2048):  752 us | throughput: 1279 TFLOPS,  870 GB/s
 > Perf (num_groups=32, expected_m_per_group= 256, n=4096, k=7168):  465 us | throughput: 1036 TFLOPS, 2294 GB/s
 > Perf (num_groups=32, expected_m_per_group= 256, n=7168, k=2048):  244 us | throughput:  985 TFLOPS, 2475 GB/s

Testing grouped contiguous GEMM for deepgemm(block_m=block_n=block_k=128):
 > Perf (num_groups= 4, expected_m_per_group=8192, n=4096, k=7168): 1505 us | throughput: 1279 TFLOPS,  412 GB/s
 > Perf (num_groups= 4, expected_m_per_group=8192, n=7168, k=2048):  796 us | throughput: 1209 TFLOPS,  748 GB/s
 > Perf (num_groups= 8, expected_m_per_group=4096, n=4096, k=7168): 1504 us | throughput: 1279 TFLOPS,  491 GB/s
 > Perf (num_groups= 8, expected_m_per_group=4096, n=7168, k=2048):  798 us | throughput: 1206 TFLOPS,  820 GB/s
 > Perf (num_groups=32, expected_m_per_group= 256, n=4096, k=7168):  482 us | throughput:  998 TFLOPS, 2211 GB/s
 > Perf (num_groups=32, expected_m_per_group= 256, n=7168, k=2048):  256 us | throughput:  939 TFLOPS, 2358 GB/s

Testing grouped contiguous GEMM for tk(block_m=block_n=block_k=128):
 > Perf (num_groups= 4, expected_m_per_group=8192, n=4096, k=7168): 2150 us | throughput:  895 TFLOPS,  289 GB/s
 > Perf (num_groups= 4, expected_m_per_group=8192, n=7168, k=2048): 1137 us | throughput:  846 TFLOPS,  524 GB/s
 > Perf (num_groups= 8, expected_m_per_group=4096, n=4096, k=7168): 2148 us | throughput:  896 TFLOPS,  344 GB/s
 > Perf (num_groups= 8, expected_m_per_group=4096, n=7168, k=2048): 1136 us | throughput:  847 TFLOPS,  576 GB/s
 > Perf (num_groups=32, expected_m_per_group= 256, n=4096, k=7168):  595 us | throughput:  808 TFLOPS, 1790 GB/s
 > Perf (num_groups=32, expected_m_per_group= 256, n=7168, k=2048):  315 us | throughput:  763 TFLOPS, 1916 GB/s

Note that the current TK implementation using multicast has about a 15% performance regression compared to the non-multicast path.
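
As a cross-check on how these figures are derived (my own accounting, assuming the benchmark counts 2·m·n·k FLOPs with m = num_groups × expected_m_per_group, FP8 A/B, a BF16 output, and ignores the FP8 scaling factors), the first deepgemm row reproduces as:

```cpp
#include <cstdio>

int main() {
    // First deepgemm row: num_groups=4, expected_m_per_group=8192, n=4096, k=7168, t=1408 us.
    const double groups = 4, n = 4096, k = 7168, t_s = 1408e-6;
    const double m = groups * 8192;                       // total rows across all groups
    const double flops = 2.0 * m * n * k;                 // one multiply-add = 2 FLOPs
    // FP8 A (m x k) and B (groups x n x k) at 1 byte/element, BF16 output (m x n) at 2 bytes.
    const double bytes = m * k * 1.0 + groups * n * k * 1.0 + m * n * 2.0;
    printf("%.0f TFLOPS, %.0f GB/s\n", flops / t_s / 1e12, bytes / t_s / 1e9);
    return 0;
}
```

This prints 1367 TFLOPS and 441 GB/s, matching the first row above; the other rows follow the same formula with their respective shapes and timings.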
