Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Added test cases that mimic the use case for shared emb_tokens/lm_head & updated comments
  • Loading branch information
jixiongdeng committed Nov 6, 2025
commit 12a8bb2eb528c194dca357493d3f9f29b9af16bc
78 changes: 72 additions & 6 deletions onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,78 @@ TEST(GatherBlockQuantizedOpTest, GatherAxis0NoZeroPoints_8Bits) {
}
#endif

template <typename T1, typename T2, typename Tind>
void Test_GatherAxis0_QuantizedAxis1_WithZeroPoints_4Bits() {
  // Models the shared 4-bit token_embedding/lm_head use case on CUDA.

  // Operator attributes. With 16 columns per row and a block size of 16,
  // each row is covered by exactly one quantization block.
  constexpr int64_t kGatherAxis = 0;
  constexpr int64_t kQuantizeAxis = 1;  // Last axis (required for CUDA)
  constexpr int64_t kBlockSize = 16;
  constexpr int64_t kBits = 4;

  // Two rows of unpacked 4-bit values.
  std::vector<int64_t> input_dims = {2, 16};
  std::vector<int> input_values = {
      -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7,   // row 0
      0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1};  // row 1

  // One scale and one explicit zero point per row.
  std::vector<int64_t> scale_dims = {2, 1};
  std::vector<float> scale_values = {2.0f, 1.0f};
  std::vector<int> zero_point_values = {-2, 1};

  // Only row 1 is gathered.
  std::vector<int64_t> index_dims = {1};
  std::vector<int> index_values = {1};

  // Walk-through for row 1:
  //   unpacked data      : [0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1]
  //   packed (offset +8) : [8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7]
  //   scale              : 1.0f
  //   zero point         : unpacked 1, packed 9 (offset +8)
  // The expected values below equal the packed data: this test is written on the
  // premise that CUDA does not subtract the zero point here.
  // NOTE(review): the 8-bit counterpart of this test expects (q - zero_point) * scale;
  // confirm that the difference between the two is intended.
  std::vector<int64_t> expected_dims = {1, 16};
  std::vector<float> expected_values = {8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f,
                                        0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f};

  RunUnpackedData<T1, T2, Tind>(input_values, input_dims,
                                index_values, index_dims,
                                scale_values, scale_dims,
                                zero_point_values,
                                kGatherAxis, kQuantizeAxis, kBlockSize, kBits,
                                expected_values, expected_dims,
                                true);
}

template <typename T1, typename T2, typename Tind>
void Test_GatherAxis0_QuantizedAxis1_WithZeroPoints_8Bits() {
  // Models the shared 8-bit token_embedding/lm_head use case on CUDA.

  // Operator attributes. With 16 columns per row and a block size of 16,
  // each row is covered by exactly one quantization block.
  constexpr int64_t kGatherAxis = 0;
  constexpr int64_t kQuantizeAxis = 1;
  constexpr int64_t kBlockSize = 16;
  constexpr int64_t kBits = 8;

  // Two rows of unpacked 8-bit values.
  std::vector<int64_t> input_dims = {2, 16};
  std::vector<int> input_values = {
      -128, -127, -126, -125, -124, -123, -122, -121,
      -120, -119, -118, -117, -116, -115, -114, -113,          // row 0
      0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};   // row 1

  // One scale and one explicit zero point per row.
  std::vector<int64_t> scale_dims = {2, 1};
  std::vector<float> scale_values = {1.0f, 2.0f};
  std::vector<int> zero_point_values = {10, -5};

  // Only row 1 is gathered.
  std::vector<int64_t> index_dims = {1};
  std::vector<int> index_values = {1};

  // Walk-through for row 1:
  //   unpacked data        : [0, 1, 2, ..., 15]
  //   packed (offset +128) : [128, 129, 130, ..., 143]
  //   scale                : 2.0f
  //   zero point           : unpacked -5, packed 123 (offset +128)
  //   dequantized          : (q - 123) * 2 for q in 128..143  ->  10, 12, ..., 40
  std::vector<int64_t> expected_dims = {1, 16};
  std::vector<float> expected_values = {10.f, 12.f, 14.f, 16.f, 18.f, 20.f, 22.f, 24.f,
                                        26.f, 28.f, 30.f, 32.f, 34.f, 36.f, 38.f, 40.f};

  RunUnpackedData<T1, T2, Tind>(input_values, input_dims,
                                index_values, index_dims,
                                scale_values, scale_dims,
                                zero_point_values,
                                kGatherAxis, kQuantizeAxis, kBlockSize, kBits,
                                expected_values, expected_dims,
                                true);
}

// Runs the 4-bit gather-axis-0 / quantize-axis-1 scenario with explicit zero points.
// T1 is fixed to uint8_t; every combination of T2 in {float, MLFloat16} and
// Tind in {int32_t, int64_t} is exercised.
TEST(GatherBlockQuantizedOpTest, GatherAxis0_QuantizedAxis1_Uint8_4Bits_WithZeroPoints) {
// Tind = int32_t
Test_GatherAxis0_QuantizedAxis1_WithZeroPoints_4Bits<uint8_t, float, int32_t>();
Test_GatherAxis0_QuantizedAxis1_WithZeroPoints_4Bits<uint8_t, MLFloat16, int32_t>();
// Tind = int64_t
Test_GatherAxis0_QuantizedAxis1_WithZeroPoints_4Bits<uint8_t, float, int64_t>();
Test_GatherAxis0_QuantizedAxis1_WithZeroPoints_4Bits<uint8_t, MLFloat16, int64_t>();
}

// Runs the 8-bit gather-axis-0 / quantize-axis-1 scenario with explicit zero points.
// T1 is fixed to uint8_t; every combination of T2 in {float, MLFloat16} and
// Tind in {int32_t, int64_t} is exercised.
TEST(GatherBlockQuantizedOpTest, GatherAxis0_QuantizedAxis1_Uint8_8Bits_WithZeroPoints) {
// Tind = int32_t
Test_GatherAxis0_QuantizedAxis1_WithZeroPoints_8Bits<uint8_t, float, int32_t>();
Test_GatherAxis0_QuantizedAxis1_WithZeroPoints_8Bits<uint8_t, MLFloat16, int32_t>();
// Tind = int64_t
Test_GatherAxis0_QuantizedAxis1_WithZeroPoints_8Bits<uint8_t, float, int64_t>();
Test_GatherAxis0_QuantizedAxis1_WithZeroPoints_8Bits<uint8_t, MLFloat16, int64_t>();
}

template <typename T1, typename T2, typename Tind>
void Test_GatherAxis1_WithZeroPoints() {
std::vector<int> data = {-8, -7, -6, -5,
Expand Down Expand Up @@ -709,12 +781,6 @@ TEST(GatherBlockQuantizedOpTest, GatherAxisWithZeroPointsNoPading) {

template <typename T1, typename T2, typename Tind>
void Test_GatherAxis_NoPading_4bit() {
// BUG: CUDA kernel does NOT subtract default zero point (8 for 4-bit) when no explicit zero_points provided
// Row 0 data: [-8,-7,-6,-5,...] repeated
// Packed (add offset 8): [0,1,2,3,...] repeated
// Row 0 scales: [1.0, 2.0, 1.0] (indices 0,1,2 from [1,2,1,2,1,2])
// CUDA output: [0*1, 1*1, 2*1, 3*1, ...] [0*2, 1*2, 2*2, 3*2, ...] [0*1, 1*1, 2*1, 3*1, ...]
// = [0, 1, 2, 3, ...] [0, 2, 4, 6, ...] [0, 1, 2, 3, ...]
std::vector<int> data = {
-8, -7, -6, -5, -8, -7, -6, -5, -8, -7, -6, -5, -8, -7, -6, -5,
-8, -7, -6, -5, -8, -7, -6, -5, -8, -7, -6, -5, -8, -7, -6, -5,
Expand Down
Loading