@jixiongdeng
Contributor

Description

The after_gather_dim handling in the CUDA backend currently supports only the uint8 dtype.
This PR ensures that indexing in gather_block_quantized is correct for 4-bit weights packed two per byte (nibbles).

Motivation and Context

This allows token_embeddings and lm_head to be tied (shared) as 4-bit weights, which saves memory and compresses models further.

Contributor

Copilot AI left a comment

Pull Request Overview

This PR fixes the kernel indexing calculation for packed uint8_t data with bits < 8 in the GatherBlockQuantized operation. When sub-byte quantization is used (e.g., 4-bit values packed into uint8_t), the output dimensions are expanded to account for unpacking, but the after_gather_dim parameter passed to the kernel was not adjusted accordingly, leading to incorrect indexing.

  • Introduced calculation for after_gather_dim_unpacked that accounts for packed data expansion when using sub-8-bit quantization with uint8_t
  • Updated the kernel parameter to use the unpacked dimension value for correct indexing in the CUDA kernel
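
The dimension adjustment described in the two bullets above can be sketched on the host side as follows. This is a minimal illustration, not the PR's code: the helper names `ComponentsPerElement` and `AfterGatherDimUnpacked` are hypothetical, and the sketch assumes that only `uint8_t` storage with `bits < 8` needs expansion while every other storage type is passed through unchanged.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

// Hypothetical helper: how many logical values one storage element holds.
// Assumption: only uint8_t with bits < 8 is treated as a packed container.
template <typename T1>
inline int64_t ComponentsPerElement(int bits) {
  assert(bits > 0 && bits <= 8 && 8 % bits == 0);
  return (std::is_same<T1, uint8_t>::value && bits < 8) ? 8 / bits : 1;
}

// Hypothetical helper: scale the packed after_gather_dim so that it matches
// the unpacked output tensor that the kernel iterates over.
template <typename T1>
inline int64_t AfterGatherDimUnpacked(int64_t after_gather_dim, int bits) {
  return after_gather_dim * ComponentsPerElement<T1>(bits);
}
```

For example, with 16 packed bytes after the gather axis and 4-bit values, the kernel would receive 32 rather than 16.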
Comments suppressed due to low confidence (1)

onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cc:1

  • The kernel uses after_gather_dim (unpacked value) for indexing into the output, but constructs in_idx for the input data which is still packed. When T1 is uint8_t with bits < 8, the input data is packed, so in_idx should be computed using the original packed after_gather_dim value, not the unpacked one. This mismatch could cause incorrect memory access when reading from the packed input data.
// Copyright (c) Microsoft Corporation. All rights reserved.
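
To make the concern in the suppressed comment concrete, here is a rough host-side sketch of how an unpacked element index maps back to a location in packed storage. It is not the kernel's actual code: `PackedLocation`, `LocatePacked`, and `ReadPacked` are hypothetical names, and the sketch assumes the first value of each byte sits in the low bits.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical description of where an unpacked element lives in packed storage.
struct PackedLocation {
  int64_t byte_index;  // index into the packed uint8_t buffer
  int shift;           // bit offset of the value inside that byte
};

// Map an unpacked element index to its byte and bit offset.
// Assumption: values are stored low bits first within each byte.
inline PackedLocation LocatePacked(int64_t unpacked_idx, int bits) {
  assert(bits > 0 && bits < 8 && 8 % bits == 0);
  const int64_t components = 8 / bits;
  return {unpacked_idx / components,
          static_cast<int>(unpacked_idx % components) * bits};
}

// Read one sub-byte value from a packed buffer.
inline uint8_t ReadPacked(const uint8_t* data, int64_t unpacked_idx, int bits) {
  const PackedLocation loc = LocatePacked(unpacked_idx, bits);
  const unsigned mask = (1u << bits) - 1u;
  return static_cast<uint8_t>((data[loc.byte_index] >> loc.shift) & mask);
}
```

The point of the sketch is that the byte index is always derived by dividing an unpacked index by the number of components, so the packed and unpacked dimensions have to be kept apart when both the input and the output are indexed.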

@tianleiwu
Contributor

tianleiwu commented Nov 3, 2025

Please add a test case for this.

I noticed that CUDA has disabled a test case of 4 bits:


For reference, here is AI's analysis of this code change:

Looks good. This change is a necessary fix for correct indexing when using uint8_t to store packed 4-bit data.

Here's a breakdown of the review:

Summary of the Change

This PR modifies the ComputeInternal function in gather_block_quantized.cc. It introduces after_gather_dim_unpacked to correctly calculate the after_gather_dim parameter that is passed to the CUDA kernel.

Specifically, if the input data type T1 is uint8_t and bits is less than 8 (e.g., 4), it means uint8_t is being used as a container for multiple packed values (e.g., two 4-bit "nibbles").
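
As a small illustration of what "packed" means here, the following sketch packs 4-bit values two per byte. `PackNibbles` is a hypothetical helper written for this example; it is not the `PackDataForUint8TypeIfNecessary` function from the test file, and the low-nibble-first ordering is an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: pack 4-bit values into uint8_t containers.
// Assumption: the first value of each pair goes into the low nibble.
inline std::vector<uint8_t> PackNibbles(const std::vector<uint8_t>& values) {
  // Round up so that an odd trailing value still gets its own byte.
  std::vector<uint8_t> packed((values.size() + 1) / 2, 0);
  for (std::size_t i = 0; i < values.size(); ++i) {
    assert(values[i] < 16 && "each value must fit in 4 bits");
    const int shift = (i % 2 == 0) ? 0 : 4;
    packed[i / 2] |= static_cast<uint8_t>(values[i] << shift);
  }
  return packed;
}
```

Four logical values therefore occupy two bytes of storage, which is why the unpacked last dimension is twice the packed one for 4-bit data.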

Analysis

  1. Problem: The GatherBlockQuantizedKernel in gather_block_quantized.cu (line 51) calculates an index out_idx based on the output tensor. It then uses after_gather_dim (line 56) to map this output index back to the corresponding input element.
  2. Shape Mismatch: The ComputeInternal function already correctly calculates the output shape by expanding the last dimension by components (lines 99-102) to account for the unpacked data. However, the after_gather_dim was previously calculated based on the packed input data shape.
  3. The Fix: This mismatch causes incorrect indexing in the kernel. The kernel's after_gather_dim parameter must be based on the unpacked output tensor's dimensions, because out_idx iterates over the unpacked output. This change correctly multiplies after_gather_dim by the number of components, ensuring the indexing logic in the kernel is sound.
  4. Consistency:
    • This logic is consistent with the operator's shape inference function (provided in the prompt, lines 201-208), which also expands the output dimension by components.
    • It is also consistent with the test file (gather_block_quantized_op_test.cc), which includes a PackDataForUint8TypeIfNecessary function (line 35) and specific tests for 4-bit and 8-bit uint8_t data (e.g., Test_GatherAxis0_WithZeroPoints_Uint8), confirming this packing-aware logic is required.
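
The mismatch described in the analysis can be reproduced with a simplified host-side model. This sketch assumes a gather on axis 0 with a 2-D output, so a flat output index splits into a gather position and an offset; `OutCoord` and `Decompose` are hypothetical names and the real kernel's index arithmetic is more general.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical coordinate of one output element for a gather on axis 0.
struct OutCoord {
  int64_t gather_pos;  // which gathered row the element belongs to
  int64_t offset;      // position of the element inside that row
};

// Split a flat output index using the row length the caller supplies.
// If the output is unpacked, the row length must be the unpacked one.
inline OutCoord Decompose(int64_t out_idx, int64_t after_gather_dim) {
  assert(after_gather_dim > 0);
  return {out_idx / after_gather_dim, out_idx % after_gather_dim};
}
```

With 4 packed bytes per row and 4-bit values, each output row has 8 elements. Decomposing `out_idx = 5` with the unpacked length 8 gives row 0, offset 5, whereas using the packed length 4 gives row 1, offset 1, which points at the wrong gathered row.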

Conclusion: This is a correct and necessary fix. The change is clear, well-commented, and aligns with the existing logic for handling packed uint8_t data in the operator's spec and tests.

@jixiongdeng
Contributor Author

@tianleiwu @kunal-vaishnavi I added test cases; please take a look.
They include the use cases for a shared 4-bit/8-bit emb_tokens/lm_head (the motivation for this PR).
Thanks to @xiaomsft for sharing his unmerged test cases: https://github.com/xiaomsft/onnxruntime/tree/xiaoh/gather_block_quantized_tests.
I also included UInt4X2 dtype test cases, so all dtypes should be covered now.

@jixiongdeng
Contributor Author

The related tests built successfully:

[ 88%] Building CXX object CMakeFiles/onnxruntime_provider_test.dir/onnxruntime/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc.o
[ 88%] Linking CXX executable onnxruntime_provider_test
/usr/bin/ld: warning: QgemmU8S8KernelAmx.S.o: missing .note.GNU-stack section implies executable stack
/usr/bin/ld: NOTE: This behaviour is deprecated and will be removed in a future version of the linker
[100%] Built target onnxruntime_provider_test

and passed:

[==========] Running 5 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 5 tests from GatherBlockQuantizedOpTest
[ RUN      ] GatherBlockQuantizedOpTest.GatherAxis0_QuantizedAxis1_Uint8_4Bits_WithZeroPoints
[       OK ] GatherBlockQuantizedOpTest.GatherAxis0_QuantizedAxis1_Uint8_4Bits_WithZeroPoints (2560 ms)
[ RUN      ] GatherBlockQuantizedOpTest.GatherAxis0_QuantizedAxis1_Uint8_8Bits_WithZeroPoints
[       OK ] GatherBlockQuantizedOpTest.GatherAxis0_QuantizedAxis1_Uint8_8Bits_WithZeroPoints (97 ms)
[ RUN      ] GatherBlockQuantizedOpTest.GatherAxisWithZeroPointsNoPading
[       OK ] GatherBlockQuantizedOpTest.GatherAxisWithZeroPointsNoPading (195 ms)
[ RUN      ] GatherBlockQuantizedOpTest.GatherAxisNoPadingUInt8_4Bits
[       OK ] GatherBlockQuantizedOpTest.GatherAxisNoPadingUInt8_4Bits (96 ms)
[ RUN      ] GatherBlockQuantizedOpTest.GatherAxisNoPadingUInt8
[       OK ] GatherBlockQuantizedOpTest.GatherAxisNoPadingUInt8 (100 ms)
[----------] 5 tests from GatherBlockQuantizedOpTest (3049 ms total)

[----------] Global test environment tear-down
[==========] 5 tests from 1 test suite ran. (3050 ms total)
[  PASSED  ] 5 tests.

@tianleiwu tianleiwu enabled auto-merge (squash) November 6, 2025 20:07
@tianleiwu tianleiwu merged commit d7b48f8 into main Nov 6, 2025
103 of 104 checks passed
@tianleiwu tianleiwu deleted the jdeng/shared_4bit_emb branch November 6, 2025 20:07
5 participants