TEMP (wip): Partial work towards a unified kernel implementation of SSD
Branch: Mamba2SSD

Signed-off-by: Gabe Goodhart <[email protected]>
gabe-l-hart committed Dec 8, 2025
commit 6340ab1bc8a127d68353e6c0ebe360dd948dd4e3
25 changes: 25 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -454,6 +454,31 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_me
return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan_ssd(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);

char base[256];
char name[256];

const int nsg = (ne00 + 31)/32;

snprintf(base, 256, "kernel_ssm_scan_ssd_%s", ggml_type_name(op->src[0]->type));
snprintf(name, 256, "%s_nsg=%d", base, nsg);

ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
}

// Shared memory layout must match kernel_ssm_scan_ssd_f32 (see ggml-metal.metal),
// assuming the kernel is launched with nsg simdgroups of width 32:
// - nsg * 32 floats for the per-simdgroup partial sums
// - nsg floats for the pre-computed x*dt values
// - nsg floats for the pre-computed softplus(dt) values used to form dA
res.smem = (nsg*32 + 2*nsg) * sizeof(float);

return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t lib, const ggml_tensor * op) {
char base[256];
char name[256];
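For reference only (not part of the commit), the snippet below sketches how the requested threadgroup memory scales with d_state, assuming the kernel runs nsg = (d_state + 31)/32 simdgroups of width 32 and uses the shared-memory layout described in kernel_ssm_scan_ssd_f32 further down. The helper name and the main() driver are illustrative.

#include <cstdio>

// Illustrative only: recomputes the threadgroup memory the SSD pipeline would
// request for a given d_state, following the kernel's layout of
// nsg*32 floats (partial sums) + nsg floats (x*dt) + nsg floats (softplus(dt)).
static size_t ssm_scan_ssd_smem_bytes(int d_state) {
    const int nsg    = (d_state + 31) / 32;     // simdgroups per threadgroup
    const int floats = nsg * 32 + nsg + nsg;    // sums + shared_x_dt + shared_dA
    return (size_t) floats * sizeof(float);
}

int main() {
    // e.g. d_state = 128 -> 4 simdgroups -> (128 + 8) floats * 4 bytes = 544 bytes
    printf("d_state=128 -> %zu bytes of threadgroup memory\n", ssm_scan_ssd_smem_bytes(128));
    return 0;
}
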
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-device.h
@@ -119,6 +119,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan_ssd (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op);
8 changes: 7 additions & 1 deletion ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -1471,7 +1471,13 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
/*.nb0 =*/ nb0,
};

auto pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);
auto pipeline = (n_seq_tokens > 1)
? ggml_metal_library_get_pipeline_ssm_scan_ssd(lib, op)
: ggml_metal_library_get_pipeline_ssm_scan(lib, op);

// To fall back to the sequential scan for all cases (the SSD kernel still
// needs further optimization to be competitive with it), use:
// auto pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);

GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));

160 changes: 152 additions & 8 deletions ggml/src/ggml-metal/ggml-metal.metal
@@ -2477,31 +2477,41 @@ kernel void kernel_ssm_scan_f32(

threadgroup_barrier(mem_flags::mem_threadgroup);

for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
// Phase 1: Compute states and s*C products for all tokens in batch
// Store partial products, delay reduction
const int batch_len = min((int)sgptg, n_t - i2);
device const float * B_t = B;
device const float * C_t = C;

for (int t = 0; t < batch_len; t++) {
const float x_dt = shared_x_dt[t];
const float dA = exp(shared_dA[t] * A0);

s = (s0 * dA) + (B[i0] * x_dt);
s = (s0 * dA) + (B_t[i0] * x_dt);

const float sumf = simd_sum(s * C[i0]);
// Compute s * C and do SIMD reduction
const float sumf = simd_sum(s * C_t[i0]);

if (tiisg == 0) {
shared_sums[t*NW + sgitg] = sumf;
}

// recurse
s0 = s;

B += args.ns42;
C += args.ns52;
B_t += args.ns42;
C_t += args.ns52;
}

// Advance pointers for next batch
// Advance B, C pointers for next batch
B += batch_len * args.ns42;
C += batch_len * args.ns52;

// Advance x, dt pointers for next batch
x += sgptg * args.ns12;
dt += sgptg * args.ns21;

threadgroup_barrier(mem_flags::mem_threadgroup);

// Phase 2: Final reduction and output
const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]);

if (tiisg == 0 && i2 + sgitg < n_t) {
@@ -2514,6 +2524,140 @@ kernel void kernel_ssm_scan_f32(
s_buff[i] = s;
}

// SSD kernel for multi-token processing
//
// The SSM state update s[t] = dA[t] * s[t-1] + B[t] * x[t] * dt[t] forms an
// associative scan with operator: (c1,v1) ⊕ (c2,v2) = (c2*c1, c2*v1 + v2),
// which in principle allows an O(log n) work-efficient (Blelloch) prefix scan
// per threadgroup instead of an O(n) sequential recurrence.
//
// NOTE: this WIP version still walks each batch of sgptg tokens sequentially,
// mirroring kernel_ssm_scan_f32; the parallel prefix scan (which would need a
// power-of-two thread count >= n_seq_tokens) is the intended follow-up.
//
// Dispatch: one threadgroup per (head_dim_idx, head, seq)
kernel void kernel_ssm_scan_ssd_f32(
constant ggml_metal_kargs_ssm_scan & args,
device const void * src0,
device const void * src1,
device const void * src2,
device const void * src3,
device const void * src4,
device const void * src5,
device const void * src6,
device float * dst,
threadgroup float * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgptg[[simdgroups_per_threadgroup]],
uint3 tgpg[[threadgroups_per_grid]]) {

constexpr short NW = N_SIMDWIDTH;

const int32_t i0 = tpitg.x; // state index within d_state
const int32_t i1 = tgpig.x; // head_dim index
const int32_t ir = tgpig.y; // head index
const int32_t i3 = tgpig.z; // sequence index

const int32_t nc = args.d_state;
const int32_t nr = args.d_inner; // head_dim
const int32_t nh = args.n_head;
const int32_t ng = args.n_group;
const int32_t n_t = args.n_seq_tokens;

const int32_t s_off = args.s_off;
const int32_t g = ir / (nh / ng); // group index for B, C

device const int32_t * ids = (device const int32_t *) src6;

// State buffers
device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);

const int32_t state_idx = i0 + i1*nc;

// Load initial state
float s0 = s0_buff[state_idx];

// A coefficient
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
const float A0 = A[i0 % args.ne30];

// Input pointers
device const float * x_base = (device const float *)((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i3*args.nb13);
device const float * dt_base = (device const float *)((device const char *) src2 + ir*args.nb20 + i3*args.nb22);
device const float * B_base = (device const float *)((device const char *) src4 + g*args.nb41 + i3*args.nb43);
device const float * C_base = (device const float *)((device const char *) src5 + g*args.nb51 + i3*args.nb53);

// Output pointer
device float * y_base = dst + (i1 + ir*nr + i3*(n_t*nh*nr));

// Shared memory layout:
// - sgptg * NW floats for partial sums
// - sgptg floats for shared_x_dt
// - sgptg floats for shared_dA
threadgroup float * shared_sums = shared;
threadgroup float * shared_x_dt = shared + sgptg * NW;
threadgroup float * shared_dA = shared + sgptg * NW + sgptg;

shared_sums[tpitg.x] = 0.0f;

float s = 0.0f;

// Process tokens in batches of sgptg
for (int i2 = 0; i2 < n_t; i2 += sgptg) {
threadgroup_barrier(mem_flags::mem_threadgroup);

// Pre-compute x_dt and dA for this batch of tokens
if (i0 < sgptg && i2 + i0 < n_t) {
device const float * x_t = x_base + i0 * args.ns12;
device const float * dt_t = dt_base + i0 * args.ns21;

const float dt0 = dt_t[0];
const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
shared_x_dt[i0] = x_t[0] * dtsp;
shared_dA[i0] = dtsp;
}

threadgroup_barrier(mem_flags::mem_threadgroup);

// Process tokens in batch sequentially (standard approach)
for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
const float x_dt = shared_x_dt[t];
const float dA = exp(shared_dA[t] * A0);

s = (s0 * dA) + (B_base[i0] * x_dt);

const float sumf = simd_sum(s * C_base[i0]);

if (tiisg == 0) {
shared_sums[t*NW + sgitg] = sumf;
}

s0 = s;

B_base += args.ns42;
C_base += args.ns52;
}

// Advance pointers for next batch
x_base += sgptg * args.ns12;
dt_base += sgptg * args.ns21;

threadgroup_barrier(mem_flags::mem_threadgroup);

const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]);

if (tiisg == 0 && i2 + sgitg < n_t) {
y_base[sgitg*nh*nr] = sumf;
}

y_base += sgptg*nh*nr;
}

s_buff[state_idx] = s;
}

kernel void kernel_rwkv_wkv6_f32(
device const float * k,
device const float * v,
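As a side note on the associative-scan formulation described in the kernel_ssm_scan_ssd_f32 header comment, the sketch below (plain C++, not part of the commit, all names and values illustrative) checks that folding per-token pairs (dA[t], B[t]*x[t]*dt[t]) with the operator (c1,v1) ⊕ (c2,v2) = (c2*c1, c2*v1 + v2) reproduces the sequential recurrence; a Blelloch scan would evaluate the same prefixes in O(log n) parallel steps.

#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

// One scan element: s[t] = c*s[t-1] + v, with c = dA[t], v = B[t]*x[t]*dt[t].
struct ScanElem { float c; float v; };

// Associative combine: applying e1 then e2 to a state s gives
// e2.c*(e1.c*s + e1.v) + e2.v, i.e. (e2.c*e1.c, e2.c*e1.v + e2.v).
static ScanElem combine(ScanElem e1, ScanElem e2) {
    return { e2.c * e1.c, e2.c * e1.v + e2.v };
}

int main() {
    // Toy per-token coefficients for a single (state, head_dim, head, seq) slot.
    std::vector<ScanElem> elems = {
        {0.9f, 0.10f}, {0.8f, 0.20f}, {0.7f, 0.05f}, {0.95f, 0.30f},
    };
    const float s_init = 0.5f;

    // Sequential recurrence (what the current kernels do token by token).
    std::vector<float> s_seq;
    float s = s_init;
    for (const ScanElem & e : elems) {
        s = e.c * s + e.v;
        s_seq.push_back(s);
    }

    // Inclusive prefix scan over the combine operator.
    std::vector<float> s_scan;
    ScanElem acc = {1.0f, 0.0f}; // identity element: s -> 1*s + 0
    for (const ScanElem & e : elems) {
        acc = combine(acc, e);
        s_scan.push_back(acc.c * s_init + acc.v);
    }

    for (size_t t = 0; t < elems.size(); ++t) {
        assert(std::fabs(s_seq[t] - s_scan[t]) < 1e-6f);
        printf("t=%zu  sequential=%f  scan=%f\n", t, s_seq[t], s_scan[t]);
    }
    return 0;
}
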
6 changes: 4 additions & 2 deletions src/models/graph-context-mamba.cpp
@@ -247,8 +247,10 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i
auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());

if (n_seq_tokens == 1) {
// if (true) {
// Use SSM_SCAN op for all cases - the Metal kernel handles both
// single-token (sequential scan) and multi-token (SSD formulation) internally
if (true) {
// if (n_seq_tokens == 1) {
//DEBUG
LLAMA_LOG_DEBUG("build_mamba2_layer(layer %d): using ssm_scan op\n", il);
// Use the ssm_scan op (covers both the single-token and multi-token paths)
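To make the "one op for both cases" comment above concrete, the following rough reference (plain C++, illustrative names only, not the ggml API) spells out the per-slot recurrence the Metal kernels above implement: softplus on dt, state decay by exp(dt_sp * A), accumulation of B * x * dt_sp, and y[t] as the dot product of the state with C[t]. The single-token decode path is simply the n_t == 1 case of the same loop.

#include <cmath>
#include <cstdio>
#include <vector>

// Reference recurrence for one (head_dim, head, seq) slot with d_state states.
// Mirrors the Metal kernels above: dt passes through softplus, the state decays
// by exp(dt_sp * A) and accumulates B * (x * dt_sp), and each token's output is
// the dot product of the updated state with C.
static std::vector<float> ssm_scan_reference(
        std::vector<float>                    & s,    // [d_state] state, updated in place
        const std::vector<float>              & A,    // [d_state]
        const std::vector<float>              & x,    // [n_t]
        const std::vector<float>              & dt,   // [n_t]
        const std::vector<std::vector<float>> & B,    // [n_t][d_state]
        const std::vector<std::vector<float>> & C) {  // [n_t][d_state]
    const size_t d_state = s.size();
    const size_t n_t     = x.size();

    std::vector<float> y(n_t, 0.0f);
    for (size_t t = 0; t < n_t; ++t) {
        const float dt_sp = dt[t] <= 20.0f ? std::log1p(std::exp(dt[t])) : dt[t];
        const float x_dt  = x[t] * dt_sp;
        for (size_t i = 0; i < d_state; ++i) {
            s[i] = s[i] * std::exp(dt_sp * A[i]) + B[t][i] * x_dt;
            y[t] += s[i] * C[t][i];
        }
    }
    return y; // n_t == 1 gives the single-token decode case of the same math
}

int main() {
    std::vector<float> s  = {0.0f, 0.0f};
    std::vector<float> A  = {-1.0f, -0.5f};
    std::vector<float> x  = {1.0f, 0.5f, 0.25f};
    std::vector<float> dt = {0.1f, 0.2f, 0.3f};
    std::vector<std::vector<float>> B = {{1.0f, 0.5f}, {0.8f, 0.4f}, {0.6f, 0.3f}};
    std::vector<std::vector<float>> C = {{1.0f, 1.0f}, {1.0f, 1.0f}, {1.0f, 1.0f}};

    const std::vector<float> y = ssm_scan_reference(s, A, x, dt, B, C);
    for (size_t t = 0; t < y.size(); ++t) {
        printf("y[%zu] = %f\n", t, y[t]);
    }
    return 0;
}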