CUDA: experimental native mxfp4 support for blackwell [WIP] #17906
base: master
Conversation
Nice speedup.

Master: Device 0: NVIDIA GeForce RTX 5070 Ti, compute capability 12.0, VMM: yes
PR: Device 0: NVIDIA GeForce RTX 5070 Ti, compute capability 12.0, VMM: yes
ggml/src/ggml-cuda/common.cuh
Outdated
if (sign > 0.0f) {
    return static_cast<uint8_t>(best_i);       // 0..7
} else {
    return static_cast<uint8_t>(best_i | 0x8); // 8..15
}
I think it would be slightly more optimal to extract the sign bit from x, do a bit shift, and a logical and.
More generally, there are FP4 conversion intrinsics in the CUDA math API but I'm not sure whether they would be of use.
ggml/src/ggml-cuda/mmq.cuh
Outdated
x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 0] = compress(aux_q4[1]) << 16 | compress(aux_q4[0]);
x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 1] = compress(aux_q4[3]) << 16 | compress(aux_q4[2]);
x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 2] = compress(aux_q4[1] >> 4) << 16 | compress(aux_q4[0] >> 4);
x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 3] = compress(aux_q4[3] >> 4) << 16 | compress(aux_q4[2] >> 4);
At this point in the code you should be suffering from a 4-way shared memory bank conflict.
    return 0;
}

const uint8_t sign_bit = x < 0.0f ? 0x8 : 0;
I don't know if the compiler is smart enough to do this optimization but I meant to transplant the sign bit directly without the use of conditional statements at all. So cast the float to an unsigned integer, shift 28 bits to the right, and apply & 0x8.
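A minimal sketch of that branchless variant (the helper name make_fp4_code is hypothetical; x is the input float and best_i the magnitude index from the LUT search, as in the diff above):

```cuda
// Sketch of the suggested branchless sign transplant: bit-cast the float,
// shift the sign bit (bit 31) down to bit 3, and mask it into the FP4 sign position.
static __device__ __forceinline__ uint8_t make_fp4_code(const float x, const int best_i) {
    const uint32_t bits     = __float_as_uint(x);  // bit-cast float -> uint32_t
    const uint8_t  sign_bit = (bits >> 28) & 0x8;  // 0x0 for positive, 0x8 for negative
    return static_cast<uint8_t>(best_i | sign_bit); // 0..7 positive, 8..15 negative
}
```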
// Saturate to max representable magnitude
if (ax > pos_lut[7]) {
    ax = pos_lut[7];
}
Suggested change: remove the saturation block quoted above.
It should be fine to remove this: values > 6 will automatically map to the last LUT entry anyway, since it has the smallest error.
#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
#define MMQ_MMA_TILE_X_K_FP4 (MMQ_TILE_NE_K + MMQ_TILE_NE_K / QI8_0)
The resulting value is correct, but I don't think you should be calculating it like this, since it will be confusing. It would be better to use something like MMQ_TILE_NE_K + 4, though ideally you would replace the hardcoded value with something that indicates where it comes from.
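One possible shape for that, purely as a sketch (the constant name MMQ_MMA_TILE_X_K_PADDING and the assumption that the +4 is a shared-memory padding term are mine, not taken from the PR):

```cuda
// Hypothetical naming sketch: give the hardcoded +4 a name that records why it exists.
#define MMQ_MMA_TILE_X_K_PADDING 4   // assumed: padding to avoid shared memory bank conflicts
#define MMQ_MMA_TILE_X_K_FP4 (MMQ_TILE_NE_K + MMQ_MMA_TILE_X_K_PADDING)
```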
    case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_FP4;
#else
    case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
#endif
Suggested change: #endif → #endif // BLACKWELL_MMA_AVAILABLE
const int k0 = kbx * 4;
memcpy(x_qs + i * MMQ_MMA_TILE_X_K_FP4 + k0, bxi->qs, 16);
This needs a comment mentioning that the data is permuted vs. the q8_0 path and that this is handled via permutation in quantize_mmq_mxfp4.
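For illustration, the requested note could look roughly like this (the wording is a sketch based on this review comment, not the PR author's):

```cuda
// NOTE: the nibbles copied here are permuted relative to the q8_0 path;
// quantize_mmq_mxfp4 applies the matching permutation on the activation side.
const int k0 = kbx * 4;
memcpy(x_qs + i * MMQ_MMA_TILE_X_K_FP4 + k0, bxi->qs, 16);
```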
offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
constexpr size_t sz = type == GGML_TYPE_MXFP4 ? sizeof(block_fp4_mmq) : sizeof(block_q8_1_mmq);
This also needs a check for BLACKWELL_MMA_AVAILABLE.
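One way the guard could look, as a sketch only (assuming BLACKWELL_MMA_AVAILABLE is the preprocessor macro used elsewhere in this PR, and that the q8_1 size should be kept on other architectures):

```cuda
// Sketch: only select the mxfp4 tile size when the Blackwell MMA path is compiled in.
#if defined(BLACKWELL_MMA_AVAILABLE)
    constexpr size_t sz = type == GGML_TYPE_MXFP4 ? sizeof(block_fp4_mmq) : sizeof(block_q8_1_mmq);
#else
    constexpr size_t sz = sizeof(block_q8_1_mmq);
#endif // BLACKWELL_MMA_AVAILABLE
```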
offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
constexpr size_t sz = type == GGML_TYPE_MXFP4 ? sizeof(block_fp4_mmq) : sizeof(block_q8_1_mmq);
Same as above.
const uint8_t q_lo_0 = __shfl_sync(0xFFFFFFFF, q_val, base, WARP_SIZE);
const uint8_t q_lo_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 1, WARP_SIZE);
const uint8_t q_hi_0 = __shfl_sync(0xFFFFFFFF, q_val, base + 16, WARP_SIZE);
const uint8_t q_hi_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 17, WARP_SIZE);
This needs a comment to explain the permutation.
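As a sketch of what such a comment might say (the link to quantize_mmq_mxfp4 is inferred from the earlier thread, not confirmed by the PR):

```cuda
// NOTE: lanes base, base+1, base+16 and base+17 are gathered because the FP4
// nibbles are stored in a permuted order (assumed to come from quantize_mmq_mxfp4;
// see the note there for the exact layout).
const uint8_t q_lo_0 = __shfl_sync(0xFFFFFFFF, q_val, base, WARP_SIZE);
const uint8_t q_lo_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 1, WARP_SIZE);
const uint8_t q_hi_0 = __shfl_sync(0xFFFFFFFF, q_val, base + 16, WARP_SIZE);
const uint8_t q_hi_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 17, WARP_SIZE);
```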
Currently WIP, trying to add native fp4 support for blackwell and beyond. To compile
-DCMAKE_CUDA_ARCHITECTURES="120a" is required. Blackwell has an m16n8k64 instruction for 4 bit (mxfp4, nvfp4 and int4) which advertises 2x throughput compared to int8 tensor cores. However, at the moment this PR is ~~10% slower than master~~ 25% faster than master on PP. The other issue is that we quantize activations to mxfp4 instead of q8, which leads to failures in test-backend-ops; however, PPL tests are okay with this change (though not ruling out correctness issues).

TODO:
Master:
This PR:
Device 0: NVIDIA GeForce RTX 5090, compute capability 12.0, VMM: yes