feat(kernel): implement simple attention without kv cache
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Feb 19, 2024
commit 20a34ac3fe9489371b785cb49f143ff236e4121d
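
For context, the "simple attention without kv cache" added in this commit computes, per head, softmax(mask(Q·Kᵀ / √headDim))·V: a scaled QK matmul, a masked row-wise softmax, then a matmul with V. The sketch below is a minimal single-head CPU reference of that computation, included only as an illustration; it is not code from this commit, and every name in it is hypothetical.

#include <cmath>
#include <vector>

// Hypothetical reference: q, k, v are [seqLen, headDim] row-major, out is
// [seqLen, headDim]. Causal mask without kv cache: query i attends to keys 0..i.
void attentionReference(const float *q, const float *k, const float *v,
                        float *out, int seqLen, int headDim) {
    const float scale = 1.0f / std::sqrt(static_cast<float>(headDim));
    std::vector<float> att(seqLen);// one row of the attention matrix
    for (int i = 0; i < seqLen; ++i) {
        // scaled dot products against every visible key, then a stable softmax
        float maxVal = -INFINITY;
        for (int j = 0; j <= i; ++j) {
            float dot = 0;
            for (int d = 0; d < headDim; ++d)
                dot += q[i * headDim + d] * k[j * headDim + d];
            att[j] = dot * scale;
            maxVal = std::fmax(maxVal, att[j]);
        }
        float sum = 0;
        for (int j = 0; j <= i; ++j)
            sum += att[j] = std::exp(att[j] - maxVal);
        // weighted sum of values
        for (int d = 0; d < headDim; ++d) {
            float acc = 0;
            for (int j = 0; j <= i; ++j)
                acc += att[j] / sum * v[j * headDim + d];
            out[i * headDim + d] = acc;
        }
    }
}
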
84 changes: 43 additions & 41 deletions src/04kernel/src/kernels/attention/cuda_kernel.cu
@@ -40,22 +40,21 @@ namespace refactor::kernel {
shared[i] = mask(blockIdx.y, gridDim.y, i, attLen) ? float(att[i]) : -__FLT_MAX__;
}

// float local_max = -1e20;
// for (int i = threadIdx.x; i < len_buf; i += blockDim.x) {
// local_max = fmaxf(local_max, smem[i]);
// }
// local_max = functions::blockReduceMax<float>(local_max);
float localMax = -1e20;
for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
localMax = cub::Max()(localMax, shared[i]);
}
localMax = cuda::blockReduce(localMax, -1e20f, cub::Max());

// float local_sum = 1e-20;
// for (int i = threadIdx.x; i < len_buf; i += blockDim.x) {
// float v = expf(float(smem[i]) - local_max);
// smem[i] = v;
// local_sum += v;
// }
// local_sum = functions::blockReduceSum<float>(local_sum);
// for (int i = threadIdx.x; i < len_buf; i += blockDim.x) {
// x[offset + i] = float(smem[i]) / local_sum;
// }
float localSum = 1e-20;
for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
localSum += shared[i] = expf(shared[i] - localMax);
}
localSum = cuda::blockReduce(localSum, 1e-20f, cub::Sum());
auto reciprocal = fdividef(1, localSum);
for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
att[i] = shared[i] * reciprocal;
}
}
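
The helpers used in the kernel above (the mask functor and cuda::blockReduce) are defined elsewhere in the repository and are not part of this diff. As a hedged illustration of the same pattern — one thread block normalizes one attention row, with block-wide max and sum reductions — here is a standalone sketch that uses cub::BlockReduce directly. The causal predicate is a guess at a typical rule, not the repository's actual mask functor, and the att layout ([batch * nHead, seqLen, attLen]) is assumed.

#include <cub/block/block_reduce.cuh>
#include <cuda_fp16.h>
#include <cfloat>

// Hypothetical causal predicate: without a kv cache attLen == seqLen, so the
// query at position qIdx may only attend to keys 0..qIdx.
__device__ bool causalVisible(unsigned qIdx, unsigned seqLen,
                              unsigned kIdx, unsigned attLen) {
    return kIdx + seqLen <= qIdx + attLen;
}

// One block per attention row; grid = (batch * nHead, seqLen), blockDim.x = BLOCK.
// Dynamic shared memory holds one row as float (attLen * sizeof(float) bytes).
template<unsigned BLOCK>
__global__ void rowSoftmaxSketch(half *att, unsigned attLen) {
    using Reduce = cub::BlockReduce<float, BLOCK>;
    __shared__ typename Reduce::TempStorage temp;
    __shared__ float broadcast;
    extern __shared__ float row[];

    half *line = att + (blockIdx.x * gridDim.y + blockIdx.y) * attLen;
    // load the row, masking invisible keys to -FLT_MAX
    for (auto i = threadIdx.x; i < attLen; i += blockDim.x)
        row[i] = causalVisible(blockIdx.y, gridDim.y, i, attLen)
                     ? float(line[i])
                     : -FLT_MAX;

    // block-wide max, broadcast to every thread
    float localMax = -FLT_MAX;
    for (auto i = threadIdx.x; i < attLen; i += blockDim.x)
        localMax = fmaxf(localMax, row[i]);
    float m = Reduce(temp).Reduce(localMax, cub::Max());
    if (threadIdx.x == 0) broadcast = m;
    __syncthreads();
    m = broadcast;

    // exponentiate, block-wide sum, then normalize in place
    float localSum = 0;
    for (auto i = threadIdx.x; i < attLen; i += blockDim.x)
        localSum += row[i] = expf(row[i] - m);
    float s = Reduce(temp).Reduce(localSum, cub::Sum());
    if (threadIdx.x == 0) broadcast = s;
    __syncthreads();
    s = broadcast;

    for (auto i = threadIdx.x; i < attLen; i += blockDim.x)
        line[i] = row[i] / s;
}

A launch mirroring the configuration used later in lower() would pass one block per (head, query-position) pair and attLen * sizeof(float) dynamic shared memory; the actual kernel in this file additionally takes the mask as a parameter.
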

RoutineWorkspace K::lower(Resources &res) const {
@@ -141,35 +140,38 @@ namespace refactor::kernel {
auto att = reinterpret_cast<half *>(workspace);
auto workspaceQK = reinterpret_cast<uint8_t *>(workspace) + hardware::alignBytes(d->attSize, 256);
auto workspaceAV = workspaceQK + hardware::alignBytes(d->workspaceSizeQK, 256);

float alpha = 1, beta = 0;
cublasLtMatmul(
handle, d->mul.get(),
&alpha,
q, d->q.get(),
k, d->k.get(),
&beta,
att, d->att.get(),
att, d->att.get(),
&d->algoQK,
workspaceQK, d->workspaceSizeQK,
cudaStreamLegacy);

{
half alpha = rsqrtf(info.headDim), beta = 0;
cublasLtMatmul(
handle, d->mul.get(),
&alpha,
q, d->q.get(),
k, d->k.get(),
&beta,
att, d->att.get(),
att, d->att.get(),
&d->algoQK,
workspaceQK, d->workspaceSizeQK,
cudaStreamLegacy);
}
softmax<<<dim3(info.batch * info.nHead, info.seqLen), info.seqLen>>>(
att, causualMask, info.seqLen, info.seqLen);

cublasLtMatmul(
handle, d->mul.get(),
&alpha,
att, d->att.get(),
v, d->v.get(),
&beta,
o, d->q.get(),
o, d->q.get(),
&d->algoAV,
workspaceAV, d->workspaceSizeAV,
cudaStreamLegacy);
{
half alpha = 1, beta = 0;
cublasLtMatmul(
handle, d->mul.get(),
&alpha,
att, d->att.get(),
v, d->v.get(),
&beta,
o, d->q.get(),
o, d->q.get(),
&d->algoAV,
workspaceAV, d->workspaceSizeAV,
cudaStreamLegacy);
};
};

return {std::move(routine), workspaceSize};
}
}
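
The sizes carved out of the shared workspace above (attSize, workspaceSizeQK, workspaceSizeAV, and the total workspaceSize) are computed outside the hunks shown in this commit. As a rough sketch of how such a single-blob layout is typically sized — an assumption with hypothetical names, not the repository's hardware::alignBytes — the fp16 attention matrix comes first, followed by the two cublasLt workspaces, each region aligned up to 256 bytes:

#include <cstddef>
#include <cstdint>

// Hypothetical: round size up to the next multiple of align (256 in the diff).
constexpr size_t alignUp(size_t size, size_t align) {
    return (size + align - 1) / align * align;
}

// Hypothetical total size of the workspace blob used by the routine above.
size_t attentionWorkspaceBytes(size_t batch, size_t nHead, size_t seqLen,
                               size_t workspaceSizeQK, size_t workspaceSizeAV) {
    size_t attSize = batch * nHead * seqLen * seqLen * sizeof(uint16_t);// half
    return alignUp(attSize, 256)
         + alignUp(workspaceSizeQK, 256)
         + alignUp(workspaceSizeAV, 256);
}

One design point worth noting: folding the 1/sqrt(headDim) attention scale into the QK GEMM's alpha (in half precision) avoids a separate elementwise scaling pass before the softmax.
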