fix(kernel): still use cub::BlockReduce and fix Attention
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Feb 19, 2024
commit f5d975684f78d2092a9c148073db3069ae39a26d
31 changes: 0 additions & 31 deletions src/04kernel/cuda/include/kernel/cuda/reduce.cuh

This file was deleted.
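With this shared helper gone, the kernels below call cub::BlockReduce directly and broadcast the result themselves, since BlockReduce::Reduce leaves the reduced value only in thread 0. A minimal sketch of that inlined pattern (the function name and the BLOCK_DIM parameter are illustrative, not part of this commit):

#include <cub/block/block_reduce.cuh>

// Reduce a per-thread value across the block, then broadcast it from thread 0
// so that every thread can use the result afterwards.
template<int BLOCK_DIM, class T, class Op>
__device__ T reduceAndBroadcast(T value, Op op) {
    using BlockReduce = cub::BlockReduce<T, BLOCK_DIM>;
    __shared__ typename BlockReduce::TempStorage tempStorage;
    __shared__ T result;
    value = BlockReduce(tempStorage).Reduce(value, op);// valid in thread 0 only
    if (threadIdx.x == 0) { result = value; }
    __syncthreads();
    return result;
}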

24 changes: 17 additions & 7 deletions src/04kernel/src/kernels/attention/cuda_kernel.cu
@@ -1,7 +1,8 @@
#include "../../utilities/cuda/cublaslt_utils.cuh"
#include "cuda_kernel.hh"
#include "hardware/functions.h"
#include "kernel/cuda/reduce.cuh"
#include "kernel/cuda/functions.cuh"
#include <cub/block/block_reduce.cuh>

namespace refactor::kernel {
using K = AttentionCuda;
@@ -27,7 +28,7 @@ namespace refactor::kernel {

// gridDim.x = batch * nHead
// gridDim.y = seqLen
// blockDim.x = min(1024, attLen)
// blockDim.x = 1024
// sizeof(shared) = attLen * sizeof(float)
template<class T, class Mask>
static __global__ void softmax(
@@ -36,25 +37,34 @@
uint32_t attLen,
uint32_t bufLen) {
// locate the attention region handled by this thread block
att += (blockIdx.x * gridDim.x + gridDim.y) * bufLen;
att += (blockIdx.x * gridDim.y + blockIdx.y) * bufLen;
// load the input into shared memory, with cast + mask
extern __shared__ float shared[];// size = attLen = pastSeqLen + seqLen
for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
shared[i] = mask(blockIdx.y, gridDim.y, i, attLen) ? float(att[i]) : -__FLT_MAX__;
}

using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage tempStorage;
__shared__ float sharedMax, sharedSum;

float localMax = -1e20;
for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
localMax = cub::Max()(localMax, shared[i]);
}
localMax = cuda::blockReduce(localMax, -1e20f, cub::Max());
localMax = BlockReduce(tempStorage).Reduce(localMax, cub::Max(), attLen);
if (threadIdx.x == 0) { sharedMax = localMax; }
__syncthreads();

float localSum = 1e-20;
for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
localSum += shared[i] = expf(shared[i] - localMax);
localSum += shared[i] = expf(shared[i] - sharedMax);
}
localSum = cuda::blockReduce(localSum, 1e-20f, cub::Sum());
auto reciprocal = fdividef(1, localSum);
localSum = BlockReduce(tempStorage).Reduce(localSum, cub::Sum(), attLen);
if (threadIdx.x == 0) { sharedSum = localSum; }
__syncthreads();

auto reciprocal = fdividef(1, sharedSum);
for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
att[i] = shared[i] * reciprocal;
}
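For reference, the grid/block comments above imply a launch like the sketch below. The CausalMask functor and the exact parameter order of softmax are assumptions inferred from how mask and att are used inside the kernel; the real definitions live elsewhere in this file.

// Hypothetical causal mask matching the call mask(blockIdx.y, gridDim.y, i, attLen):
// token `tokenId` of the current sequence may attend positions 0 .. tokenId + (attLen - seqLen).
struct CausalMask {
    __device__ __forceinline__ bool operator()(uint32_t tokenId, uint32_t seqLen,
                                               uint32_t posId, uint32_t attLen) const {
        return posId + seqLen <= tokenId + attLen;
    }
};

// one block per (batch * nHead) x seqLen pair, 1024 threads,
// attLen floats of dynamic shared memory for the cached row
softmax<float, CausalMask><<<dim3(batch * nHead, seqLen), 1024,
                             attLen * sizeof(float), stream>>>(
    att, CausalMask{}, attLen, bufLen);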
12 changes: 7 additions & 5 deletions src/04kernel/src/kernels/softmax/cuda_kernel.cu
@@ -1,5 +1,5 @@
#include "cuda_kernel.hh"
#include "kernel/cuda/reduce.cuh"
#include <cub/cub.cuh>

namespace refactor::kernel {
using namespace runtime;
@@ -18,8 +18,8 @@ namespace refactor::kernel {
template<> __device__ __forceinline__ nv_bfloat16 reciprocal<nv_bfloat16>(nv_bfloat16 x) { return hrcp(x); }

// blockDim.x === BLOCK_DIM
template<class T>
__global__ void blockSoftmaxKernel(
template<int BLOCK_DIM, class T>
__launch_bounds__(BLOCK_DIM) __global__ void blockSoftmaxKernel(
T const *__restrict x,
T *__restrict y,
int mid,
@@ -40,8 +40,10 @@ namespace refactor::kernel {
for (int i = threadIdx.x + blockDim.x; i < mid; i += blockDim.x) {
maxSumThread = MaxSum::reduce(maxSumThread, {x[id + i * stride], 1});// reduce the data to one block
}
using BlockReduce = cub::BlockReduce<MaxSum, BLOCK_DIM>;
__shared__ typename BlockReduce::TempStorage tempStorage;
__shared__ MaxSum maxSumTotal;
auto maxSumBlock = cuda::blockReduce(maxSumThread, {-__FLT_MAX__, 0}, MaxSum::reduce);
auto maxSumBlock = BlockReduce(tempStorage).Reduce(maxSumThread, MaxSum::reduce);
if (threadIdx.x == 0) {
maxSumTotal = maxSumBlock;// must set threadIdx.x = 0 write the output to memory
}
@@ -111,7 +113,7 @@ namespace refactor::kernel {
auto y = reinterpret_cast<T *>(outputs[0]);
int numBlocks = info.pre * info.post;
if (info.mid > 1024) {
blockSoftmaxKernel<<<numBlocks, 1024>>>(x, y, info.mid, info.post);
blockSoftmaxKernel<1024><<<numBlocks, 1024>>>(x, y, info.mid, info.post);
} else {
int blockDimX, mid = static_cast<int>(info.mid);
for (blockDimX = 32; blockDimX > 4 && mid < blockDimX; blockDimX /= 2) {}
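The switch to a compile-time BLOCK_DIM is what cub::BlockReduce<MaxSum, BLOCK_DIM> needs: its shared-memory layout is fixed at compile time, and __launch_bounds__(BLOCK_DIM) tells the compiler the kernel is never launched with more threads than that. MaxSum::reduce itself is presumably the standard online-softmax merge of (max, sum-of-exp) pairs; a sketch of that rule, for context only:

// Illustrative only: merge two (max, sum of exp(x - max)) partial results so
// that seeding with {x, 1} accumulates the row maximum and the softmax denominator.
struct MaxSum {
    float max, sum;
    static __device__ __forceinline__ MaxSum reduce(MaxSum a, MaxSum b) {
        float m = fmaxf(a.max, b.max);
        return {m, a.sum * expf(a.max - m) + b.sum * expf(b.max - m)};
    }
};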
17 changes: 13 additions & 4 deletions src/04kernel/test/kernels/attention/test_cuda.cpp
@@ -13,7 +13,7 @@ using namespace hardware;
TEST(kernel, AttentionCudaNoKvCache) {
// build routine
AttentionInfo info{
.dataType = DataType::FP16,
.dataType = DataType::F32,
.batch = 1,
.nHead = 4,
.nKVHead = 4,
@@ -23,9 +23,9 @@ TEST(kernel, AttentionCudaNoKvCache) {
.concatCache = false,
.resetCache = false,
};
auto q = Tensor::share(DataType::FP16, Shape{info.batch, info.nHead, info.seqLen, info.headDim}),
k = Tensor::share(DataType::FP16, Shape{info.batch, info.nKVHead, info.seqLen, info.headDim}),
v = Tensor::share(DataType::FP16, Shape{info.batch, info.nKVHead, info.seqLen, info.headDim}),
auto q = Tensor::share(DataType::F32, Shape{info.batch, info.nHead, info.seqLen, info.headDim}),
k = Tensor::share(DataType::F32, Shape{info.batch, info.nKVHead, info.seqLen, info.headDim}),
v = Tensor::share(DataType::F32, Shape{info.batch, info.nKVHead, info.seqLen, info.headDim}),
o = q;
auto kernel = AttentionCuda::build(info);
ASSERT_TRUE(kernel);
@@ -38,6 +38,15 @@ TEST(kernel, AttentionCudaNoKvCache) {
vGpu = dev.malloc(v->bytesSize()),
oGpu = dev.malloc(o->bytesSize()),
workspace = dev.malloc(workspaceSize);
// put input data
std::vector<float>
q_(q->elementsSize(), 1),
k_(k->elementsSize(), 1),
v_(v->elementsSize(), 1),
o_(o->elementsSize());
qGpu->copyFromHost(q_.data());
kGpu->copyFromHost(k_.data());
vGpu->copyFromHost(v_.data());
// inference
{
void const *inputs[]{*qGpu, *kGpu, *vGpu};
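With all-ones Q, K and V, every attention score in a row is identical, so softmax yields uniform weights over the unmasked positions and the output is an average of all-ones value rows, i.e. exactly 1 everywhere. A pure host-side reference (standard C++ only; a causal mask and nKVHead == nHead are assumed, matching this test) that the copied-back output could be compared against:

#include <cmath>
#include <vector>

// Naive reference attention for the layout {batch, nHead, seqLen, headDim}.
// With all-ones inputs it returns all ones.
static std::vector<float> referenceAttention(
    std::vector<float> const &q, std::vector<float> const &k,
    std::vector<float> const &v, int batch, int nHead, int seqLen, int headDim) {
    std::vector<float> o(q.size());
    auto scale = 1.f / std::sqrt(static_cast<float>(headDim));
    for (int bh = 0; bh < batch * nHead; ++bh) {
        auto base = bh * seqLen * headDim;
        for (int i = 0; i < seqLen; ++i) {
            std::vector<float> w(i + 1);
            float max = -INFINITY, sum = 0;
            for (int j = 0; j <= i; ++j) {// causal: token i attends tokens 0..i
                float s = 0;
                for (int d = 0; d < headDim; ++d) {
                    s += q[base + i * headDim + d] * k[base + j * headDim + d];
                }
                w[j] = s * scale;
                max = std::fmax(max, w[j]);
            }
            for (int j = 0; j <= i; ++j) { sum += w[j] = std::exp(w[j] - max); }
            for (int d = 0; d < headDim; ++d) {
                float acc = 0;
                for (int j = 0; j <= i; ++j) { acc += w[j] * v[base + j * headDim + d]; }
                o[base + i * headDim + d] = acc / sum;
            }
        }
    }
    return o;
}

Comparing this reference against the device output element-wise within FP32 tolerance would confirm the corrected indexing in the attention softmax kernel.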