From f663b95314321c1b7d0992eac610cc09af1d6330 Mon Sep 17 00:00:00 2001
From: Jeff Bolz
Date: Tue, 9 Dec 2025 14:25:38 -0600
Subject: [PATCH] vulkan: Multi-pass softmax for large number of cols

When the number of cols is large, split each row across multiple
workgroups. There are three phases that communicate partial results
through temp buffers:

(1) compute max partials
(2) take max of partials, compute sum(exp(x-max)) partials
(3) sum partials, compute scaled result
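
For reference, the decomposition implemented by the three shaders can be
sketched on the CPU roughly as follows. This is an illustration only: it
ignores the scale, mask/ALiBi slope and sink terms that the shaders fold in,
the function names are made up, and CHUNK stands in for
BLOCK_SIZE * num_iters (128 * 4):

    #include <cmath>
    #include <cstddef>

    static const std::size_t CHUNK = 512; // elements handled per "workgroup"

    // Phase 1: each workgroup reduces its chunk of the row to a partial max.
    static void phase1_partial_max(const float * x, std::size_t n,
                                   float * pmax, std::size_t nwg) {
        for (std::size_t w = 0; w < nwg; ++w) {
            float m = -INFINITY;
            for (std::size_t i = w * CHUNK; i < n && i < (w + 1) * CHUNK; ++i) {
                m = std::fmax(m, x[i]);
            }
            pmax[w] = m;
        }
    }

    // Phase 2: every workgroup re-reduces the partial maxima to the row max,
    // stores exp(x - max) for its chunk, and writes a partial sum of those values.
    static void phase2_partial_sum(const float * x, std::size_t n,
                                   const float * pmax, std::size_t nwg,
                                   float * y, float * psum) {
        float m = -INFINITY;
        for (std::size_t w = 0; w < nwg; ++w) {
            m = std::fmax(m, pmax[w]);
        }
        for (std::size_t w = 0; w < nwg; ++w) {
            float s = 0.0f;
            for (std::size_t i = w * CHUNK; i < n && i < (w + 1) * CHUNK; ++i) {
                y[i] = std::exp(x[i] - m);
                s += y[i];
            }
            psum[w] = s;
        }
    }

    // Phase 3: sum the partial sums and scale the stored exponentials.
    static void phase3_scale(float * y, std::size_t n,
                             const float * psum, std::size_t nwg) {
        float sum = 0.0f;
        for (std::size_t w = 0; w < nwg; ++w) {
            sum += psum[w];
        }
        for (std::size_t i = 0; i < n; ++i) {
            y[i] *= 1.0f / sum;
        }
    }

The pmax and psum arrays play the role of the data_m and data_s temp buffers
(prealloc_x/prealloc_y), each sized num_wgs * nrows_x floats.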
---
 ggml/src/ggml-vulkan/ggml-vulkan.cpp          | 64 ++++++++++++++-
 .../vulkan-shaders/soft_max_large1.comp       | 62 +++++++++++++++
 .../vulkan-shaders/soft_max_large2.comp       | 79 +++++++++++++++++++
 .../vulkan-shaders/soft_max_large3.comp       | 65 +++++++++++++++
 .../vulkan-shaders/soft_max_large_common.glsl | 53 +++++++++++++
 .../vulkan-shaders/vulkan-shaders-gen.cpp     |  7 ++
 tests/test-backend-ops.cpp                    |  3 +
 7 files changed, 331 insertions(+), 2 deletions(-)
 create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp
 create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp
 create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp
 create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index c6f5809ccd7..4ed1981d029 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -722,6 +722,11 @@ struct vk_device_struct {
     vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
     vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
     vk_pipeline pipeline_soft_max_back_f32;
+
+    vk_pipeline pipeline_soft_max_large1_f32, pipeline_soft_max_large1_f32_f16;
+    vk_pipeline pipeline_soft_max_large2_f32, pipeline_soft_max_large2_f32_f16;
+    vk_pipeline pipeline_soft_max_large3_f32, pipeline_soft_max_large3_f32_f16;
+
     vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16, pipeline_rope_norm_f32_f16;
     vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16;
     vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
@@ -3996,6 +4001,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1, true);
 
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large1_f32, "soft_max_large1_f32", soft_max_large1_f32_len, soft_max_large1_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large2_f32, "soft_max_large2_f32", soft_max_large2_f32_len, soft_max_large2_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large3_f32, "soft_max_large3_f32", soft_max_large3_f32_len, soft_max_large3_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large1_f32_f16, "soft_max_large1_f32_f16", soft_max_large1_f32_f16_len, soft_max_large1_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large2_f32_f16, "soft_max_large2_f32_f16", soft_max_large2_f32_f16_len, soft_max_large2_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large3_f32_f16, "soft_max_large3_f32_f16", soft_max_large3_f32_f16_len, soft_max_large3_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
+
     ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@@ -10111,7 +10123,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
     const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SOFT_MAX, {
+    vk_op_soft_max_push_constants pc {
         ncols,
         src1 != nullptr ? nrows_y : (uint32_t)0,
         (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
@@ -10122,7 +10134,55 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
         n_head_log2,
         nrows_x,
         src2 != nullptr
-    });
+    };
+
+    if (ncols <= 16384) {
+        ggml_vk_op_f32(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SOFT_MAX, std::move(pc));
+    } else {
+
+        vk_subbuffer buf_a = ggml_vk_tensor_subbuffer(ctx, src0);
+        vk_subbuffer buf_b = src1 ? ggml_vk_tensor_subbuffer(ctx, src1) : buf_a;
+        vk_subbuffer buf_c = src2 ? ggml_vk_tensor_subbuffer(ctx, src2) : buf_a;
+        vk_subbuffer buf_d = ggml_vk_tensor_subbuffer(ctx, dst);
+
+        uint32_t elems_per_wg = 128 * 4;
+        uint32_t num_wgs = CEIL_DIV(ncols, elems_per_wg);
+        size_t tmp_size = num_wgs * nrows_x * sizeof(float);
+
+        if (ctx->prealloc_size_x < tmp_size) {
+            ctx->prealloc_size_x = tmp_size;
+            ggml_vk_preallocate_buffers(ctx, subctx);
+        }
+        if (ctx->prealloc_size_y < tmp_size) {
+            ctx->prealloc_size_y = tmp_size;
+            ggml_vk_preallocate_buffers(ctx, subctx);
+        }
+        if (ctx->prealloc_x_need_sync || ctx->prealloc_y_need_sync) {
+            ggml_vk_sync_buffers(ctx, subctx);
+        }
+
+        vk_subbuffer buf_x = { ctx->prealloc_x, 0, tmp_size };
+        vk_subbuffer buf_y = { ctx->prealloc_y, 0, tmp_size };
+
+        std::array<uint32_t, 3> elements = { num_wgs, nrows_x, 1 };
+
+        vk_pipeline pipeline1 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large1_f32_f16 : ctx->device->pipeline_soft_max_large1_f32;
+        vk_pipeline pipeline2 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large2_f32_f16 : ctx->device->pipeline_soft_max_large2_f32;
+        vk_pipeline pipeline3 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large3_f32_f16 : ctx->device->pipeline_soft_max_large3_f32;
+
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline1, 1);
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline2, 1);
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline3, 1);
+
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline1, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
+        ggml_vk_sync_buffers(ctx, subctx);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline2, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
+        ggml_vk_sync_buffers(ctx, subctx);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline3, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
+
+        ctx->prealloc_x_need_sync = true;
+        ctx->prealloc_y_need_sync = true;
+    }
 }
 
 static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp
new file mode 100644
index 00000000000..39c46639122
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp
@@ -0,0 +1,62 @@
+#version 450
+
+#include "soft_max_large_common.glsl"
+
+void main() {
+    const uint tid = gl_LocalInvocationID.x;
+    const uint rowx = gl_WorkGroupID.y;
+    const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters;
+
+    const uint32_t i03 = rowx / (p.ne01 * p.ne02);
+    const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
+    const uint32_t i01 = rowx % p.ne01;
+
+    uint rowy_start = 0;
+    if (p.KY > 0) {
+        rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
+    }
+
+    if (rowx >= p.nrows_x) {
+        return;
+    }
+
+    float slope = get_slope(rowx);
+
+    // Find max
+    FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];
+
+    [[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
+        const uint col = col0 + tid;
+
+        FLOAT_TYPE a = FLOAT_TYPE(0);
+        if (col < p.KX) {
+            a = data_a[rowx * p.KX + col];
+        }
+
+        FLOAT_TYPE b = FLOAT_TYPE(0);
+        if (p.KY > 0 && col < p.KX) {
+            b = data_b[rowy_start + col];
+        }
+
+        FLOAT_TYPE v = a * p.scale + slope * b;
+
+        if (col < p.KX) {
+            max_val = max(max_val, v);
+        }
+    }
+
+    // reduce across the workgroup
+    vals[tid] = max_val;
+    barrier();
+    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            vals[tid] = max(vals[tid], vals[tid + s]);
+        }
+        barrier();
+    }
+
+    if (tid == 0) {
+        max_val = vals[0];
+        data_m[rowx * gl_NumWorkGroups.x + gl_WorkGroupID.x] = max_val;
+    }
+}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp
new file mode 100644
index 00000000000..69524f5f756
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp
@@ -0,0 +1,79 @@
+#version 450
+
+#include "soft_max_large_common.glsl"
+
+void main() {
+    const uint tid = gl_LocalInvocationID.x;
+    const uint rowx = gl_WorkGroupID.y;
+    const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters;
+
+    const uint32_t i03 = rowx / (p.ne01 * p.ne02);
+    const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
+    const uint32_t i01 = rowx % p.ne01;
+
+    uint rowy_start = 0;
+    if (p.KY > 0) {
+        rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
+    }
+
+    if (rowx >= p.nrows_x) {
+        return;
+    }
+
+    float slope = get_slope(rowx);
+
+    // Find max
+    FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];
+
+    [[unroll]] for (uint i = 0; i < gl_NumWorkGroups.x; i += BLOCK_SIZE) {
+        if (i + tid < gl_NumWorkGroups.x) {
+            max_val = max(max_val, data_m[rowx * gl_NumWorkGroups.x + i + tid]);
+        }
+    }
+
+    // reduce across the workgroup
+    vals[tid] = max_val;
+    barrier();
+    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            vals[tid] = max(vals[tid], vals[tid + s]);
+        }
+        barrier();
+    }
+
+    max_val = vals[0];
+    barrier();
+
+    FLOAT_TYPE sum = FLOAT_TYPE(0.0f);
+
+    // Compute sum{exp(x - max)}
+    [[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
+        const uint col = col0 + tid;
+
+        if (col >= p.KX) {
+            break;
+        }
+
+        // compute exp(a*scale+b*slope), add it to sum
+        const uint i = rowx * p.KX + col;
+        FLOAT_TYPE val;
+        val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy_start + col]) : FLOAT_TYPE(0.0f)) - max_val);
+        sum += val;
+        data_d[i] = D_TYPE(val);
+    }
+
+    // reduce across the workgroup
+    vals[tid] = sum;
+    barrier();
+    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            vals[tid] += vals[tid + s];
+        }
+        barrier();
+    }
+
+    if (tid == 0) {
+        sum = vals[0];
+        data_s[rowx * gl_NumWorkGroups.x + gl_WorkGroupID.x] = sum;
+    }
+}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp
new file mode 100644
index 00000000000..06efd7d9fb4
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp
@@ -0,0 +1,65 @@
+#version 450
+
+#include "soft_max_large_common.glsl"
+
+shared FLOAT_TYPE sumsh[BLOCK_SIZE];
+
+void main() {
+    const uint tid = gl_LocalInvocationID.x;
+    const uint rowx = gl_WorkGroupID.y;
+    const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters;
+
+    const uint32_t i03 = rowx / (p.ne01 * p.ne02);
+    const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
+    const uint32_t i01 = rowx % p.ne01;
+
+    uint rowy_start = 0;
+    if (p.KY > 0) {
+        rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
+    }
+
+    if (rowx >= p.nrows_x) {
+        return;
+    }
+
+    FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];
+    FLOAT_TYPE sum = FLOAT_TYPE(0.0f);
+
+    [[unroll]] for (uint i = 0; i < gl_NumWorkGroups.x; i += BLOCK_SIZE) {
+        if (i + tid < gl_NumWorkGroups.x) {
+            max_val = max(max_val, data_m[rowx * gl_NumWorkGroups.x + i + tid]);
+            sum += data_s[rowx * gl_NumWorkGroups.x + i + tid];
+        }
+    }
+
+    // reduce across the workgroup
+    vals[tid] = max_val;
+    sumsh[tid] = sum;
+    barrier();
+    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            vals[tid] = max(vals[tid], vals[tid + s]);
+            sumsh[tid] += sumsh[tid + s];
+        }
+        barrier();
+    }
+
+    max_val = vals[0];
+    sum = sumsh[0];
+
+    if (p.has_sinks != 0) {
+        sum += FLOAT_TYPE(exp(FLOAT_TYPE(data_c[i02]) - max_val));
+    }
+
+    FLOAT_TYPE rcpdivisor = 1.0/sum;
+
+    [[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
+        const uint col = col0 + tid;
+
+        if (col >= p.KX) {
+            continue;
+        }
+
+        data_d[rowx*p.KX + col] *= D_TYPE(rcpdivisor);
+    }
+}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl
new file mode 100644
index 00000000000..6636d1f8dea
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl
@@ -0,0 +1,53 @@
+#extension GL_EXT_control_flow_attributes : enable
+
+layout (push_constant) uniform parameter
+{
+    uint KX;
+    uint KY;
+    uint ne00;
+    uint ne01;
+    uint ne02;
+    uint ne12;
+    uint ne13;
+    uint nb11;
+    uint nb12;
+    uint nb13;
+    float scale;
+    float max_bias;
+    float m0;
+    float m1;
+    uint n_head_log2;
+    uint nrows_x;
+    uint has_sinks;
+} p;
+
+#include "types.glsl"
+
+layout(constant_id = 0) const uint BLOCK_SIZE = 128;
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+layout(constant_id = 1) const uint num_iters = 4;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
+layout (binding = 2) readonly buffer Z {float data_c[];};
+layout (binding = 3) buffer D {D_TYPE data_d[];};
+layout (binding = 4) buffer M {float data_m[];};
+layout (binding = 5) buffer S {float data_s[];};
+
+shared FLOAT_TYPE vals[BLOCK_SIZE];
+
+float get_slope(uint rowx) {
+    float slope = 1.0f;
+
+    // ALiBi
+    if (p.max_bias > 0.0f) {
+        const uint h = (rowx / p.ne01) % p.ne02; // head index
+
+        const float base = h < p.n_head_log2 ? p.m0 : p.m1;
+        const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;
+
+        slope = pow(base, exp);
+    }
+
+    return slope;
+}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
index 92bae088b20..72c63e81732 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -899,6 +899,13 @@ void process_shaders() {
     string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
     string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
 
+    string_to_spv("soft_max_large1_f32", "soft_max_large1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
+    string_to_spv("soft_max_large2_f32", "soft_max_large2.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
+    string_to_spv("soft_max_large3_f32", "soft_max_large3.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
+    string_to_spv("soft_max_large1_f32_f16", "soft_max_large1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
+    string_to_spv("soft_max_large2_f32_f16", "soft_max_large2.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
+    string_to_spv("soft_max_large3_f32_f16", "soft_max_large3.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
+
     string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
     string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
     string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 2e94a53da25..a0e7f663cfb 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -7627,6 +7627,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));
 
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200001, 2, 3, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200001, 2, 3, 1}, true, true, GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));
+
     for (float max_bias : {0.0f, 8.0f}) {
         for (float scale : {1.0f, 0.1f}) {
             for (int64_t ne0 : {16, 1024}) {