feat(kernel): support kv cache attention
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Feb 19, 2024
commit e16c6791ff4cfae247f95b49de2f7d1aa8083b23
1 change: 1 addition & 0 deletions src/04kernel/include/kernel/attributes/attention_info.h
@@ -12,6 +12,7 @@ namespace refactor::kernel {

dim_t attLen(dim_t pastSeqLen) const noexcept;
size_t attSize(dim_t pastSeqLen) const noexcept;
size_t maxAttSize() const noexcept;
};

}// namespace refactor::kernel
4 changes: 4 additions & 0 deletions src/04kernel/src/attributes/attention_info.cc
@@ -10,4 +10,8 @@ namespace refactor::kernel {
return batch * nHead * seqLen * attLen(pastSeqLen) * dataType.size();
}

size_t AttentionInfo::maxAttSize() const noexcept {
return batch * nHead * seqLen * (cacheLen ? cacheLen : seqLen) * dataType.size();
}

}// namespace refactor::kernel
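For context, maxAttSize() sizes the worst-case buffer for the attention score matrix: with a KV cache each query row may attend to up to cacheLen keys, otherwise only to seqLen. A minimal standalone sketch of the same arithmetic, using hypothetical shapes (batch = 1, nHead = 32, seqLen = 1, cacheLen = 2048, fp16) that are not taken from this diff:

#include <cstddef>

// Hypothetical shapes, for illustration only.
constexpr std::size_t batch = 1, nHead = 32, seqLen = 1, cacheLen = 2048;
constexpr std::size_t halfSize = 2;// bytes per fp16 score

// Mirrors maxAttSize(): one score per (query, key) pair, worst case attLen == cacheLen.
constexpr std::size_t maxAttBytes = batch * nHead * seqLen * cacheLen * halfSize;
static_assert(maxAttBytes == 128 * 1024, "128 KiB of scores per decode step");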
154 changes: 151 additions & 3 deletions src/04kernel/src/kernels/attention/cuda_kernel.cu
@@ -70,6 +70,20 @@ namespace refactor::kernel {
}
}

static __global__ void concatCache(
void *__restrict__ cache,
void const *__restrict__ value,
dim_t pageStrideI,// float4 elements per (batch, head) page of `value`
dim_t pageStrideO,// float4 elements per (batch, head) page of `cache`
dim_t lineStride,// float4 elements per line (not used by the copy)
dim_t pastOffset,// float4 offset of the first free line in each cache page
dim_t n) {// total float4 elements in `value`

auto tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid >= n) return;// the launch rounds the thread count up to the block size
// Keep the (batch, head) page, shift past the lines already cached.
auto dst = tid / pageStrideI * pageStrideO + pastOffset + tid % pageStrideI;
reinterpret_cast<float4 *>(cache)[dst] = reinterpret_cast<float4 const *>(value)[tid];
}
constexpr uint64_t DYNAMIC_WORKSPACE_SIZE = 40 << 20;// 40 MiB, found to be enough by experiment

RoutineWorkspace K::lower(Resources &res) const {
auto handle = res.fetchOrStore<CublasLtContext>()->handle;

Expand Down Expand Up @@ -125,8 +139,8 @@ namespace refactor::kernel {
.batchCount = static_cast<int32_t>(info.batch * info.nHead),
.batchStride = static_cast<int64_t>(info.seqLen * info.seqLen),
}) {
auto [algoQK_, workspaceSizeQK_] = tune(context.handle, mul, q, k, att);
auto [algoAV_, workspaceSizeAV_] = tune(context.handle, mul, att, v, q);
auto [algoQK_, workspaceSizeQK_] = tune(context.handle, mul, q, k, att, DYNAMIC_WORKSPACE_SIZE);
auto [algoAV_, workspaceSizeAV_] = tune(context.handle, mul, att, v, q, DYNAMIC_WORKSPACE_SIZE);
algoQK = algoQK_;
algoAV = algoAV_;
workspaceSizeQK = workspaceSizeQK_;
Expand Down Expand Up @@ -187,12 +201,146 @@ namespace refactor::kernel {
&d->algoAV,
workspaceAV, d->workspaceSizeAV,
stream);
};
}
};

return {std::move(routine), workspaceSize};
}
TODO("");
}
if (info.concatCache && !info.resetCache) {
if (info.nHead == info.nKVHead) {

// RAII for closure
struct Descriptors {
MatMulDescriptor mul;

Descriptors(AttentionInfo info)
: mul(computeTypeConvert(info.dataType),
dataTypeConvert(info.dataType)) {}
};

auto const &context = *res.fetchOrStore<CublasLtContext>();
auto d = std::make_shared<Descriptors>(info);
auto attentionSize = info.maxAttSize();
auto workspaceSize = DYNAMIC_WORKSPACE_SIZE + attentionSize;

auto routine = [d = std::move(d), info = this->info]//
(Resources & res, void *workspace, void const *const *inputs, void *const *outputs) {
auto handle = res.fetchOrStore<CublasLtContext>()->handle;
auto q = inputs[0];
auto k = inputs[1];
auto v = inputs[2];
auto past = *reinterpret_cast<int64_t const *>(inputs[3]);
auto attLen = info.attLen(past);
auto o = reinterpret_cast<half *>(outputs[0]);
auto kCache = reinterpret_cast<half *>(outputs[1]);
auto vCache = reinterpret_cast<half *>(outputs[2]);
auto att = reinterpret_cast<half *>(reinterpret_cast<uint8_t *>(workspace) + DYNAMIC_WORKSPACE_SIZE);
auto stream = cudaStreamLegacy;
{
auto itemsPerLine = info.headDim * sizeof(half) / sizeof(float4);
auto threads = info.batch * info.nHead * info.seqLen * itemsPerLine;
auto blocks = (threads + 1023) / 1024;

concatCache<<<blocks, 1024, 0, stream>>>(
kCache, k,
info.seqLen * itemsPerLine,
info.cacheLen * itemsPerLine,
itemsPerLine,
past * itemsPerLine,
threads);
concatCache<<<blocks, 1024, 0, stream>>>(
vCache, v,
info.seqLen * itemsPerLine,
info.cacheLen * itemsPerLine,
itemsPerLine,
past * itemsPerLine,
threads);
}
MatrixDescriptor
q_(MatrixLayout{
.dataType = dataTypeConvert(info.dataType),
.rows = static_cast<uint64_t>(info.seqLen),
.cols = static_cast<uint64_t>(info.headDim),
.majorStride = static_cast<int64_t>(info.headDim),
.order = ROW_MAJOR,
.batchCount = static_cast<int32_t>(info.batch * info.nHead),
.batchStride = static_cast<int64_t>(info.seqLen * info.headDim),
}),
k_(MatrixLayout{
.dataType = dataTypeConvert(info.dataType),
.rows = static_cast<uint64_t>(info.headDim),
.cols = static_cast<uint64_t>(attLen),
.majorStride = static_cast<int64_t>(info.headDim),
.order = COL_MAJOR,
.batchCount = static_cast<int32_t>(info.batch * info.nHead),
.batchStride = static_cast<int64_t>(info.cacheLen * info.headDim),
}),
v_(MatrixLayout{
.dataType = dataTypeConvert(info.dataType),
.rows = static_cast<uint64_t>(attLen),
.cols = static_cast<uint64_t>(info.headDim),
.majorStride = static_cast<int64_t>(info.headDim),
.order = ROW_MAJOR,
.batchCount = static_cast<int32_t>(info.batch * info.nHead),
.batchStride = static_cast<int64_t>(info.cacheLen * info.headDim),
}),
att_(MatrixLayout{
.dataType = dataTypeConvert(info.dataType),
.rows = static_cast<uint64_t>(info.seqLen),
.cols = static_cast<uint64_t>(attLen),
.majorStride = static_cast<int64_t>(info.cacheLen),
.order = ROW_MAJOR,
.batchCount = static_cast<int32_t>(info.batch * info.nHead),
.batchStride = static_cast<int64_t>(info.cacheLen * info.seqLen),
});
{
auto [algo, workspaceSize] = tune(
handle, d->mul,
q_, k_, att_,
DYNAMIC_WORKSPACE_SIZE);
half alpha = rsqrtf(info.headDim), beta = 0;
cublasLtMatmul(
handle, d->mul.get(),
&alpha,
q, q_.get(),
kCache, k_.get(),
&beta,
att, att_.get(),
att, att_.get(),
&algo,
workspace, workspaceSize,
stream);
}
softmax<<<dim3(info.batch * info.nHead, info.seqLen),
std::min(1024u, attLen),
attLen * sizeof(float),
stream>>>(
att, AttentionCausualMask(), attLen, info.cacheLen);
{
auto [algo, workspaceSize] = tune(
handle, d->mul,
att_, v_, q_,
DYNAMIC_WORKSPACE_SIZE);
half alpha = 1, beta = 0;
cublasLtMatmul(
handle, d->mul.get(),
&alpha,
att, att_.get(),
vCache, v_.get(),
&beta,
o, q_.get(),
o, q_.get(),
&algo,
workspace, workspaceSize,
stream);
}
};

return {std::move(routine), workspaceSize};
}
TODO("");
}

TODO("");
}
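The routine above first appends this step's K and V lines to each head's cache page at offset past, then runs Q·K^T and att·V over a window of attLen keys, striding by cacheLen so cached and new lines are read contiguously, and finally applies the causal softmax over that window. A host-side sketch of the same cache-append mapping, with hypothetical names and a plain memcpy in place of the float4-vectorized concatCache kernel:

#include <cstdint>
#include <cstring>

// Hypothetical host mirror of concatCache: append seqLen new lines to each
// (batch, head) page of a cache holding cacheLen lines per page.
void concatCacheHost(std::uint8_t *cache, std::uint8_t const *value,
                     std::size_t pages,    // batch * nHead
                     std::size_t seqLen,   // new lines per page
                     std::size_t cacheLen, // capacity in lines of each cache page
                     std::size_t past,     // lines already stored in each page
                     std::size_t lineBytes /* headDim * sizeof(half) */) {
    for (std::size_t p = 0; p < pages; ++p)
        std::memcpy(cache + (p * cacheLen + past) * lineBytes,
                    value + p * seqLen * lineBytes,
                    seqLen * lineBytes);
}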

8 changes: 4 additions & 4 deletions src/04kernel/src/utilities/cuda/cublaslt_utils.cu
@@ -101,23 +101,23 @@ namespace refactor::kernel::cublas {
MatMulDescriptor const &matmul,
MatrixDescriptor const &a,
MatrixDescriptor const &b,
MatrixDescriptor const &c) {
MatrixDescriptor const &c,
uint64_t maxWorkspace) {

int device;
CUDA_ASSERT(cudaGetDevice(&device));
cudaDeviceProp prop;
CUDA_ASSERT(cudaGetDeviceProperties(&prop, device));

auto workspace = std::numeric_limits<uint64_t>::max();
uint32_t alignment = prop.textureAlignment;

cublasLtMatmulPreference_t preference;
CUBLASLT_ASSERT(cublasLtMatmulPreferenceCreate(&preference));
CUBLASLT_ASSERT(cublasLtMatmulPreferenceSetAttribute(
preference,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace,
sizeof(workspace)));
&maxWorkspace,
sizeof(maxWorkspace)));
CUBLASLT_ASSERT(cublasLtMatmulPreferenceSetAttribute(
preference,
CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES,
3 changes: 2 additions & 1 deletion src/04kernel/src/utilities/cuda/cublaslt_utils.cuh
@@ -68,7 +68,8 @@ namespace refactor::kernel::cublas {
MatMulDescriptor const &,
MatrixDescriptor const &,
MatrixDescriptor const &,
MatrixDescriptor const &);
MatrixDescriptor const &,
uint64_t);

}// namespace refactor::kernel::cublas
