[CUDA] Correct after_gather_dim for nibbled uint4 index #26484
Pull Request Overview
This PR fixes the kernel indexing calculation for packed uint8_t data with bits < 8 in the GatherBlockQuantized operation. When sub-byte quantization is used (e.g., 4-bit values packed into uint8_t), the output dimensions are expanded to account for unpacking, but the after_gather_dim parameter passed to the kernel was not adjusted accordingly, leading to incorrect indexing.
- Introduced calculation for after_gather_dim_unpacked that accounts for packed data expansion when using sub-8-bit quantization with uint8_t
- Updated the kernel parameter to use the unpacked dimension value for correct indexing in the CUDA kernel
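As a rough illustration of the adjustment described above, the following is a minimal sketch, not the code from this PR. The helper name AfterGatherDimUnpacked and the data_is_uint8 flag are hypothetical, and the expansion factor 8 / bits is an assumption inferred from the statement that output dimensions are expanded when sub-byte values are unpacked.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not taken from the PR): returns the number of elements
// after the gather axis once sub-byte values stored in uint8_t are unpacked.
// Assumption: each byte holds (8 / bits) quantized values.
inline int64_t AfterGatherDimUnpacked(int64_t after_gather_dim, int bits, bool data_is_uint8) {
  assert(bits > 0 && bits <= 8 && 8 % bits == 0);
  if (data_is_uint8 && bits < 8) {
    const int64_t components = 8 / bits;  // e.g. 2 values per byte for 4-bit data
    return after_gather_dim * components;
  }
  // Non-packed data: the dimension is already expressed in elements.
  return after_gather_dim;
}
```

With 4-bit data packed into uint8_t, a packed after_gather_dim of 64 would correspond to 128 unpacked elements under this assumption, while 8-bit data would be left unchanged.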
Comments suppressed due to low confidence (1)
onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cc:1
- The kernel uses after_gather_dim (unpacked value) for indexing into the output, but constructs in_idx for the input data, which is still packed. When T1 is uint8_t with bits < 8, the input data is packed, so in_idx should be computed using the original packed after_gather_dim value, not the unpacked one. This mismatch could cause incorrect memory access when reading from the packed input data.
// Copyright (c) Microsoft Corporation. All rights reserved.
Please add a test case for this. I noticed that CUDA has disabled a test case of 4 bits:
For reference, here is AI's analysis of this code change:
Looks good. This change is a necessary fix for correct indexing when using
Here's a breakdown of the review:
Summary of the Change
This PR modifies the
Specifically, if the input data type
Analysis
Conclusion: This is a correct and necessary fix. The change is clear, well-commented, and aligns with the existing logic for handling packed
@tianleiwu @kunal-vaishnavi Added test cases. Please check.
Related tests are built successfully: and passed:
Description
The after_gather_dim handling in the CUDA backend currently supports only the uint8 dtype.
This PR ensures that indexing in gather_block_quantized is correct for nibble-packed 4-bit weights.
Motivation and Context
This allows token_embeddings and lm_head to be tied with 4-bit weights, which saves memory and compresses models further.