diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index c6f5809ccd7..52e7e1e7fea 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -757,7 +757,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_flash_attn_split_k_reduce;
 
-    vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT];
+    // [2] is for whether to take n_experts from spec constant (0) or push constant (1)
+    vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT][2];
 
     std::vector<vk_pipeline_ref> all_pipelines;
 
@@ -1149,6 +1150,7 @@ static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
 
 struct vk_op_topk_moe_push_constants {
     uint32_t n_rows;
+    uint32_t n_experts_push;
     uint32_t n_expert_used;
     float clamp_min;
     float clamp_max;
@@ -4204,10 +4206,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
 
-    for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
-        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0}, 1, true, true, device->subgroup_size);
-        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0}, 1, true, true, device->subgroup_size);
-        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1}, 1, true, true, device->subgroup_size);
+    for (uint32_t use_push = 0; use_push < 2; ++use_push) {
+        for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
+            ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX][use_push], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0, use_push}, 1, true, true, device->subgroup_size);
+            ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM][use_push], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0, use_push}, 1, true, true, device->subgroup_size);
+            ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX][use_push], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1, use_push}, 1, true, true, device->subgroup_size);
+        }
     }
 
     for (auto &c : compiles) {
@@ -8554,7 +8558,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
         GGML_ASSERT(idx < num_topk_moe_pipelines);
         topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
-        return ctx->device->pipeline_topk_moe[idx][mode];
+        // use n_experts from push constant if it's not equal to the power of two spec constant
+        bool use_push = dst->ne[0] != (1u << idx);
+        return ctx->device->pipeline_topk_moe[idx][mode][use_push];
     }
 
     if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
@@ -10158,6 +10164,7 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
 
     vk_op_topk_moe_push_constants pc {};
     pc.n_rows = n_rows;
+    pc.n_experts_push = n_experts;
     pc.n_expert_used = n_expert_used;
     if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
         ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
@@ -12832,8 +12839,7 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
     }
 
     const int n_expert = softmax->ne[0];
-    // n_expert must be a power of 2
-    if (!is_pow2(n_expert) || n_expert > (1 << (num_topk_moe_pipelines-1))) {
+    if (n_expert > (1 << (num_topk_moe_pipelines-1))) {
         return false;
     }
 
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
index 5cd0785d20f..b83a2b9d2d4 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
@@ -10,6 +10,7 @@ layout (push_constant) uniform parameter
 {
     uint n_rows;
+    uint n_experts_push;
     uint n_expert_used;
     float clamp_min;
     float clamp_max;
@@ -18,11 +19,16 @@ layout (push_constant) uniform parameter
 layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
 
 layout(constant_id = 0) const uint WARP_SIZE = 32;
-layout(constant_id = 1) const uint n_experts = 512;
+layout(constant_id = 1) const uint n_experts_spec = 512;
 layout(constant_id = 2) const bool with_norm = true;
 layout(constant_id = 3) const bool late_softmax = false;
+layout(constant_id = 4) const bool nexperts_use_push = false;
 
-const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
+uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec;
+
+#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
+
+const uint experts_per_thread = CEIL_DIV(n_experts_spec, WARP_SIZE);
 
 layout (binding = 0, std430) readonly buffer Logits {float logits[];};
 layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
@@ -94,7 +100,7 @@ void main() {
     }
 
     if (!late_softmax) {
-        softmax_warp_inplace(wt, n_experts, lane, false);
+        softmax_warp_inplace(wt, n_experts, lane, nexperts_use_push);
     }
 
     // at this point, each thread holds a portion of softmax,
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 2e94a53da25..b19458dee65 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -7927,8 +7927,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
 
     for (bool with_norm : {false, true}) {
         test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm));
+        test_cases.emplace_back(new test_topk_moe({31, 22, 1, 1}, 8, with_norm));
         test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm));
+        test_cases.emplace_back(new test_topk_moe({40, 22, 1, 1}, 8, with_norm));
+        test_cases.emplace_back(new test_topk_moe({71, 22, 1, 1}, 8, with_norm));
         test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm));
+        test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm));
     }
 
     test_cases.emplace_back(new test_topk_moe({ 8, 22, 1, 1 }, 4, /*with_norm*/ false, /*delayed_softmax*/ true));
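
For reference, a minimal standalone sketch (not part of the patch) of the selection logic in ggml_vk_op_get_pipeline above: the pipeline stays specialized for the next power of two at or above the expert count, and the push-constant path is taken whenever the real count differs from that spec constant. It assumes only standard C++ and the expressions visible in the diff.

#include <cmath>
#include <cstdint>
#include <cstdio>

// Hypothetical host-side check mirroring ggml_vk_op_get_pipeline:
// idx picks the power-of-two-specialized pipeline, use_push switches the
// shader to read n_experts_push instead of the n_experts_spec spec constant.
int main() {
    const uint32_t counts[] = {8, 31, 32, 40, 71, 128, 129};
    for (uint32_t n_experts : counts) {
        uint32_t idx      = (uint32_t)std::ceil(std::log2((double)n_experts));
        bool     use_push = n_experts != (1u << idx);
        std::printf("n_experts=%3u -> pipeline specialized for %3u, use_push=%d\n",
                    n_experts, 1u << idx, (int)use_push);
    }
    return 0;
}

The new test shapes above (31, 40, 71, and 129 experts) all land on the use_push side of this check, while the existing power-of-two shapes keep using the spec-constant value.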