fix (kernel): fix the concat in the attention operator
PanZezhong1725 committed Feb 21, 2024
commit bfa8e9ff14de6b105ece7cead5f08a859e2dcca5
25 changes: 13 additions & 12 deletions src/04kernel/src/kernels/attention/cuda_kernel.cu
@@ -75,12 +75,13 @@ namespace refactor::kernel {
             void const *__restrict__ value,
             dim_t pageStrideI,
             dim_t pageStrideO,
-            dim_t lineStride,
-            dim_t pastOffset) {
-
-            auto tid = blockIdx.x * blockDim.x + threadIdx.x,
-                 dst = tid / pageStrideO * pageStrideI + pastOffset + tid % pageStrideO;
-            reinterpret_cast<float4 *>(cache)[dst] = reinterpret_cast<float4 const *>(value)[tid];
+            dim_t pastOffset,
+            dim_t n_items) {
+            auto tid = blockIdx.x * blockDim.x + threadIdx.x;
+            if (tid < n_items) {
+                auto dst = tid / pageStrideI * pageStrideO + pastOffset + (tid % pageStrideI);
+                reinterpret_cast<float4 *>(cache)[dst] = reinterpret_cast<float4 const *>(value)[tid];
+            }
         }
         constexpr uint64_t DYNAMIC_WORKSPACE_SIZE = 40 << 20;// 40 MiB was found sufficient by experiment

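Note: the kernel fix does two things. The old index expression divided by the output stride and multiplied by the input one, scattering lines into the wrong cache pages, and the grid, rounded up to whole 1024-thread blocks, could read and write past the buffers; the new n_items guard masks those tail threads. A minimal host-side sketch of the corrected mapping, with hypothetical names and plain integers standing in for the float4 items:

// Host-side reference of the fixed concatCache index math (hypothetical
// helper, not part of the PR). Plain uint32_t stands in for float4.
#include <cassert>
#include <cstdint>
#include <vector>

using dim_t = uint32_t;

// Append n_items items from `value` (pages of pageStrideI items) into
// `cache` (pages of pageStrideO items), pastOffset items into each page.
void concatCacheRef(std::vector<uint32_t> &cache, std::vector<uint32_t> const &value,
                    dim_t pageStrideI, dim_t pageStrideO, dim_t pastOffset, dim_t n_items) {
    for (dim_t tid = 0; tid < n_items; ++tid) {
        auto dst = tid / pageStrideI * pageStrideO + pastOffset + tid % pageStrideI;
        cache[dst] = value[tid];
    }
}

int main() {
    // Two pages of 3 new items each, appended after 4 cached items per page.
    std::vector<uint32_t> value{0, 1, 2, 3, 4, 5}, cache(2 * 8, 99);
    concatCacheRef(cache, value, /*pageStrideI=*/3, /*pageStrideO=*/8,
                   /*pastOffset=*/4, /*n_items=*/6);
    assert(cache[4] == 0 && cache[6] == 2);  // page 0: items land in slots 4..6
    assert(cache[12] == 3 && cache[14] == 5);// page 1: same offsets, next page
}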
@@ -231,7 +232,8 @@ namespace refactor::kernel {
             auto q = inputs[0];
             auto k = inputs[1];
             auto v = inputs[2];
-            auto past = *reinterpret_cast<int64_t const *>(inputs[3]);
+            int64_t past;
+            cudaMemcpy(&past, inputs[3], sizeof(int64_t), cudaMemcpyDeviceToHost);
             auto attLen = info.attLen(past);
             auto o = reinterpret_cast<half *>(outputs[0]);
             auto kCache = reinterpret_cast<half *>(outputs[1]);
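Note: inputs[3] is a device pointer, so the old host-side dereference read a device address from the CPU, which is undefined behavior; the hunk replaces it with an explicit device-to-host copy. A sketch of the same pattern with error checking added (assumption: the pointer holds a single int64_t, as the diff implies); the synchronous cudaMemcpy blocks the host until the value has landed:

// Hypothetical helper showing the D2H read this hunk adopts, plus error
// checking the diff omits. Not the repository's API.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

int64_t readPastLen(void const *devPtr) {
    int64_t past = 0;
    // Synchronous copy: the host blocks until `past` is valid.
    auto err = cudaMemcpy(&past, devPtr, sizeof(int64_t), cudaMemcpyDeviceToHost);
    if (err != cudaSuccess) {
        std::fprintf(stderr, "cudaMemcpy failed: %s\n", cudaGetErrorString(err));
    }
    return past;
}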
@@ -242,19 +244,18 @@
             auto itemsPerLine = info.headDim * sizeof(half) / sizeof(float4);
             auto threads = info.batch * info.nHead * info.seqLen * itemsPerLine;
             auto blocks = (threads + 1023) / 1024;
-
             concatCache<<<blocks, 1024, 0, stream>>>(
                 kCache, k,
                 info.seqLen * itemsPerLine,
                 info.cacheLen * itemsPerLine,
-                itemsPerLine,
-                past * itemsPerLine);
+                past * itemsPerLine,
+                threads);
             concatCache<<<blocks, 1024, 0, stream>>>(
                 vCache, v,
                 info.seqLen * itemsPerLine,
                 info.cacheLen * itemsPerLine,
-                itemsPerLine,
-                past * itemsPerLine);
+                past * itemsPerLine,
+                threads);
         }
         MatrixDescriptor
             q_(MatrixLayout{
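For scale: each thread moves one float4 (16 bytes, i.e. 8 halves). A worked example of the launch arithmetic under assumed dimensions (batch 1, 32 heads, 5 new tokens, head dim 128 — illustrative numbers, not from the PR) showing why the rounded-up grid needs the n_items guard:

// Launch-size arithmetic with assumed dimensions; prints how many of the
// launched threads the new n_items guard masks off.
#include <cstdio>

int main() {
    unsigned batch = 1, nHead = 32, seqLen = 5, headDim = 128;
    auto itemsPerLine = headDim * 2 / 16;                 // halves -> float4 items: 16
    auto threads = batch * nHead * seqLen * itemsPerLine; // 2560 items to copy
    auto blocks = (threads + 1023) / 1024;                // rounds up to 3 blocks
    std::printf("%u launched, %u masked\n",
                blocks * 1024, blocks * 1024 - threads);  // 3072 launched, 512 masked
}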
12 changes: 5 additions & 7 deletions src/08-01llm/src/operators/attention.cc
@@ -80,10 +80,10 @@ namespace refactor::llm {
                 if (pastSeqLen.dataType != DataType::I64 || pastSeqLen.shape != Shape{DimExpr(1)}) {
                     return Err(InferError(ERROR_MSG("Past seqlen error")));
                 }
-                auto pastSeqLenVal = pastSeqLen.data->get<int64_t>()[0];
                 if (maxSeqLen <= 0) {
+                    auto pastSeqLenVal = pastSeqLen.data->get<int64_t>()[0];
                     return outputs(pastSeqLenVal + seqlen);
-                } else if (maxSeqLen >= pastSeqLenVal + seqlen) {
+                } else if (maxSeqLen >= 1 + seqlen) {
                     return outputs(maxSeqLen);
                 } else {
                     return Err(InferError(ERROR_MSG("max_seq_len must not less than seqlen")));
@@ -94,7 +94,6 @@ namespace refactor::llm {
                 if (pastSeqLen.dataType != DataType::I64 || pastSeqLen.shape != Shape{DimExpr(1)}) {
                     return Err(InferError(ERROR_MSG("Past seqlen error")));
                 }
-                auto pastSeqLenVal = pastSeqLen.data->get<int64_t>()[0];

                 auto const &kCahce = inputs[4],
                            &vCache = inputs[5];
@@ -107,15 +106,14 @@
                     kCahce.shape[3] != kvShape[3] ||
                     kCahce.shape[0] != kvShape[0] ||
                     kCahce.shape[2] != kvShape[2] ||
-                    kCahce.shape[3] != kvShape[3] ||
-                    pastSeqLenVal < kCacheSeqLen ||
-                    pastSeqLenVal < vCacheSeqLen) {
+                    kCahce.shape[3] != kvShape[3]) {
                     return Err(InferError(ERROR_MSG("KV cache error")));
                 }

                 if (maxSeqLen <= 0) {
+                    auto pastSeqLenVal = pastSeqLen.data->get<int64_t>()[0];
                     return outputs(pastSeqLenVal + seqlen);
-                } else if (maxSeqLen >= pastSeqLenVal + seqlen) {
+                } else if (maxSeqLen >= 1 + seqlen) {
                     return outputs(maxSeqLen);
                 } else {
                     return Err(InferError(ERROR_MSG("max_seq_len must not less than seqlen")));
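The inference-rule change is the same in both branches: the constant past length is now read only on the dynamic-output path (maxSeqLen <= 0), so graphs with a fixed max_seq_len no longer require past_seq_len to carry constant data, and the bound check compares against 1 + seqlen, presumably the smallest cache that can hold the new tokens plus any past at all. A hypothetical distillation of the rule in plain C++, not the repository's Result/InferError machinery:

// Distilled shape-inference rule after this commit (illustrative only).
#include <cstdint>
#include <optional>

// Returns the attention output length, or -1 standing in for InferError.
int64_t inferAttLen(std::optional<int64_t> pastSeqLen, int64_t seqlen, int64_t maxSeqLen) {
    if (maxSeqLen <= 0) {
        // Dynamic output: the running length needs past to be a known constant.
        if (!pastSeqLen) return -1;
        return *pastSeqLen + seqlen;
    }
    if (maxSeqLen >= 1 + seqlen) {
        // Fixed cache: valid regardless of the runtime past length.
        return maxSeqLen;
    }
    return -1;// "max_seq_len must not less than seqlen"
}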