vulkan: fix quantizing issue when tensor is not divisible by 128
0cc4m committed Aug 31, 2025
commit ed079a300138030a96d883ebf273fb70851f290f
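The commit pads the on-the-fly Q8_1 quantization of src1 so that buffers for tensors whose row length is not a multiple of 128 values still hold whole 4-block groups. The sizing sketch below is only an illustration, assuming ggml's usual Q8_1 layout (two fp16 fields plus 32 int8 values, i.e. 36 bytes per block and 144 bytes per group of four blocks covering 128 values); the helper names are invented for the example and are not part of the patch.

#include <cstdint>

// Illustrative Q8_1 sizing (assumed layout: 2x fp16 + 32x int8 = 36 bytes per block).
constexpr uint64_t QK8_1            = 32;                   // values per Q8_1 block
constexpr uint64_t Q8_1_BLOCK_BYTES = 36;                   // assumed sizeof(block_q8_1)
constexpr uint64_t Q8_1_X4_BYTES    = 4 * Q8_1_BLOCK_BYTES; // 144 bytes = 128 values

constexpr uint64_t ceil_div(uint64_t a, uint64_t b) { return (a + b - 1) / b; }

// Round a quantized size up to whole 4-block groups, mirroring the
// CEIL_DIV(..., 144) * 144 rounding introduced in this commit.
constexpr uint64_t q8_1_padded_size(uint64_t exact_bytes) {
    return ceil_div(exact_bytes, Q8_1_X4_BYTES) * Q8_1_X4_BYTES;
}

Rounding to 128 bytes, as the previous revision of the branch did, does not in general land on a whole number of 144-byte groups, which is what the hunks below correct.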
39 changes: 29 additions & 10 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -2919,7 +2919,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
uint32_t wg_size_subgroup16 = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_16 : (subgroup_size_16 * 4);
uint32_t wg_size_subgroup = (w == DMMV_WG_SIZE_SUBGROUP) ? device->subgroup_size : (device->subgroup_size * 4);

-const bool s = device->subgroup_add && device->architecture != vk_device_architecture::AMD_GCN;
+const bool s = device->subgroup_arithmetic && device->architecture != vk_device_architecture::AMD_GCN;

for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32", arr_dmmv_f32_f32_f32_len[s], arr_dmmv_f32_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1);
@@ -2970,8 +2970,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32", arr_dmmv_iq4_nl_f16_f32_len[s], arr_dmmv_iq4_nl_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32", arr_dmmv_mxfp4_f16_f32_len[s], arr_dmmv_mxfp4_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
}
+}

#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
if (device->integer_dot_product) {
const uint32_t subgroup_size = (device->subgroup_size_control && device->vendor_id == VK_VENDOR_ID_INTEL) ? device->subgroup_min_size : device->subgroup_size;
if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
@@ -2988,8 +2990,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_q8_1_f32_len, mul_mat_vec_q8_0_q8_1_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {subgroup_size, 1*rm_stdq, i+1}, 1, true);
}
}
-#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
}
+#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT

ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
@@ -5769,7 +5771,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub

if (dryrun) {
const uint64_t x_sz_upd = x_sz * ne02 * ne03;
-const uint64_t y_sz_upd = y_sz * ne12 * ne13;
+uint64_t y_sz_upd = y_sz * ne12 * ne13;
+if (quantize_y) {
+y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144;
+}
const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0;
if (
(qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
@@ -5781,7 +5786,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
ctx->prealloc_size_x = x_sz_upd;
}
if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
-ctx->prealloc_size_y = CEIL_DIV(y_sz_upd, 128) * 128;
+ctx->prealloc_size_y = y_sz_upd;
}
if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) {
ctx->prealloc_size_split_k = split_k_size;
@@ -5835,7 +5840,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
} else if (quantize_y) {
d_Y = ctx->prealloc_y;
-GGML_ASSERT(d_Y->size >= y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1));
+GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144);
} else {
d_Y = d_Qy;
y_buf_offset = qy_buf_offset;
@@ -5889,10 +5894,15 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
}

+uint32_t y_sz_total = y_sz * ne12 * ne13;
+if (quantize_y) {
+y_sz_total = CEIL_DIV(y_sz_total, 144) * 144;
+}
+
// compute
ggml_vk_matmul(
ctx, subctx, pipeline,
-{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 },
+{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total },
{ d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
ne01, ne11, ne10,
ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21,
@@ -6010,7 +6020,10 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&

if (dryrun) {
const uint64_t x_sz_upd = x_sz * ne02 * ne03;
-const uint64_t y_sz_upd = y_sz * ne12 * ne13;
+uint64_t y_sz_upd = y_sz * ne12 * ne13;
+if (quantize_y) {
+y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144;
+}
if (
(qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
(qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) {
@@ -6020,7 +6033,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
ctx->prealloc_size_x = x_sz_upd;
}
if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
-ctx->prealloc_size_y = CEIL_DIV(y_sz_upd, 128) * 128;
+ctx->prealloc_size_y = y_sz_upd;
}

// Request descriptor sets
@@ -6065,7 +6078,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
d_Y = ctx->prealloc_y;
} else if (quantize_y) {
d_Y = ctx->prealloc_y;
-GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 128) * 128);
+GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144);
} else {
d_Y = d_Qy;
y_buf_offset = qy_buf_offset;
@@ -6121,14 +6134,20 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
groups_x = CEIL_DIV(groups_x, groups_z);
}

+// TODO: Clean up this whole sz * ne_2 * ne_3 thing, it hasn't been necessary for a long time
+uint32_t y_sz_total = y_sz * ne12 * ne13;
+if (quantize_y) {
+y_sz_total = CEIL_DIV(y_sz_total, 144) * 144;
+}
+
// compute
const vk_mat_vec_push_constants pc = {
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
stride_batch_x, stride_batch_y, stride_batch_d,
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
};
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
-{ vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} },
+{ vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz_total }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} },
pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });

if (x_non_contig) {
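Taken together, the host-side hunks above make the dryrun sizing, the GGML_ASSERT on the preallocated buffer, and the subbuffer bound handed to the dispatch agree on the same 144-byte rounding. The worked example below uses the same assumed 36-byte Q8_1 block layout as the sketch above and is purely illustrative; it shows how rounding the exact quantized size to 128 bytes can fall short of the whole 4-block groups a 160-element row may touch.

#include <cstdint>
#include <cstdio>

static uint64_t ceil_div(uint64_t a, uint64_t b) { return (a + b - 1) / b; }

int main() {
    const uint64_t ne     = 160;                        // row length, not a multiple of 128
    const uint64_t exact  = ceil_div(ne, 32) * 36;      // 5 blocks * 36 B = 180 B of Q8_1 data
    const uint64_t groups = ceil_div(ne, 128);          // 2 groups of 4 blocks are addressed
    const uint64_t needed = groups * 144;               // 288 B if whole groups may be written
    const uint64_t pad128 = ceil_div(exact, 128) * 128; // 256 B: old rounding, short of 288
    const uint64_t pad144 = ceil_div(exact, 144) * 144; // 288 B: new rounding, fits exactly
    std::printf("exact=%llu needed=%llu pad128=%llu pad144=%llu\n",
                (unsigned long long) exact, (unsigned long long) needed,
                (unsigned long long) pad128, (unsigned long long) pad144);
    return 0;
}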
9 changes: 7 additions & 2 deletions ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp
@@ -45,13 +45,18 @@ void quantize() {
const uint ib = wgid * blocks_per_group + block_in_wg;
const uint iqs = tid % 8;

+#ifndef QBLOCK_X4
if (ib >= gl_NumWorkGroups.x * blocks_per_group) {
return;
}
-
-#ifdef QBLOCK_X4
+#else
const uint ibx4_outer = ib / 4;
const uint ibx4_inner = ib % 4;
+
+const uint required_x4_blocks = (p.ne + 127) / 128;
+if (ibx4_outer >= required_x4_blocks) {
+return;
+}
#endif

const uint a_idx = ib * 8 + iqs;
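On the shader side, the bounds check is now split by layout: the plain path keeps the existing gl_NumWorkGroups-based early-out, while the QBLOCK_X4 path returns once an invocation's 4-block group lies beyond the (p.ne + 127) / 128 groups the tensor actually needs, so the trailing, partially filled group is still produced but nothing past it is written. The snippet below only mirrors that index math in C++ for clarity; it is not the shader itself, and the function name is invented for the example.

#include <cstdint>

// Mirrors the QBLOCK_X4 indexing of quantize_q8_1.comp: each Q8_1 block covers
// 32 values and blocks are stored four to a group (128 values per group).
static bool should_quantize_x4(uint32_t ib /* global block index */, uint32_t ne /* p.ne */) {
    const uint32_t ibx4_outer         = ib / 4;           // which 4-block group this block is in
    const uint32_t required_x4_blocks = (ne + 127) / 128; // groups the tensor actually needs
    // Invocations whose group is entirely past the data must return early,
    // otherwise they would write beyond the 144-byte-aligned allocation.
    return ibx4_outer < required_x4_blocks;
}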