Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
refactor(kernel): 实现一种不依赖模板参数的 BlockReduce 并用于 softmax
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Feb 19, 2024
commit e54c7db38cecea327bb75d249ed223f2b332279c
24 changes: 23 additions & 1 deletion src/04kernel/cuda/include/kernel/cuda/reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,28 @@
#include <cub/warp/warp_reduce.cuh>

namespace refactor::kernel::cuda {
}

template<class T, class ReductionOp>
__inline__ __device__ T blockReduce(T x, T init, ReductionOp op) {
using WarpReduce = cub::WarpReduce<T>;
__shared__ typename WarpReduce::TempStorage tempStorage;
__shared__ T shared[32], ans;

auto reduce = WarpReduce(tempStorage);
int lane = threadIdx.x % 32;
int wid = threadIdx.x / 32;
x = reduce.Reduce(x, op);
if (lane == 0) { shared[wid] = x; }
__syncthreads();
if (wid == 0) {
x = (threadIdx.x < blockDim.x / 32) ? shared[lane] : init;
shared[lane] = reduce.Reduce(x, op);
if (lane == 0) { ans = shared[0]; }
}
__syncthreads();
return ans;// avoid RAW hazard
}

}// namespace refactor::kernel::cuda

#endif// KERNEL_CUDA_REDUCE_CUH
12 changes: 5 additions & 7 deletions src/04kernel/src/kernels/softmax/cuda_kernel.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#include "cuda_kernel.hh"
#include <cub/cub.cuh>
#include "kernel/cuda/reduce.cuh"

namespace refactor::kernel {
using namespace runtime;
Expand All @@ -18,8 +18,8 @@ namespace refactor::kernel {
template<> __device__ __forceinline__ nv_bfloat16 reciprocal<nv_bfloat16>(nv_bfloat16 x) { return hrcp(x); }

// blockDim.x === BLOCK_DIM
template<int BLOCK_DIM, class T>
__launch_bounds__(BLOCK_DIM) __global__ void blockSoftmaxKernel(
template<class T>
__global__ void blockSoftmaxKernel(
T const *__restrict x,
T *__restrict y,
int mid,
Expand All @@ -40,10 +40,8 @@ namespace refactor::kernel {
for (int i = threadIdx.x + blockDim.x; i < mid; i += blockDim.x) {
maxSumThread = MaxSum::reduce(maxSumThread, {x[id + i * stride], 1});// reduce the data to one block
}
using BlockReduce = cub::BlockReduce<MaxSum, BLOCK_DIM>;
__shared__ typename BlockReduce::TempStorage tempStorage;
__shared__ MaxSum maxSumTotal;
auto maxSumBlock = BlockReduce(tempStorage).Reduce(maxSumThread, MaxSum::reduce);
auto maxSumBlock = cuda::blockReduce(maxSumThread, {-__FLT_MAX__, 0}, MaxSum::reduce);
if (threadIdx.x == 0) {
maxSumTotal = maxSumBlock;// must set threadIdx.x = 0 write the output to memory
}
Expand Down Expand Up @@ -113,7 +111,7 @@ namespace refactor::kernel {
auto y = reinterpret_cast<T *>(outputs[0]);
int numBlocks = info.pre * info.post;
if (info.mid > 1024) {
blockSoftmaxKernel<1024><<<numBlocks, 1024>>>(x, y, info.mid, info.post);
blockSoftmaxKernel<<<numBlocks, 1024>>>(x, y, info.mid, info.post);
} else {
int blockDimX, mid = static_cast<int>(info.mid);
for (blockDimX = 32; blockDimX > 4 && mid < blockDimX; blockDimX /= 2) {}
Expand Down
19 changes: 13 additions & 6 deletions src/04kernel/test/kernels/softmax/test_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,19 @@
#include "../../../src/kernels/softmax/cuda_kernel.hh"
#include "hardware/device_manager.h"
#include <gtest/gtest.h>
#include <numeric>

using namespace refactor;
using namespace kernel;
using namespace hardware;

TEST(kernel, SoftmaxCuda) {
static void test(Shape shape, int axis) {
// build routine
auto xTensor = Tensor::share(DataType::F32, Shape{2, 3, 2, 5, 4});
auto outTensor = Tensor::share(DataType::F32, Shape{2, 3, 2, 5, 4});
dim_t axis = 1;
auto kCpu = SoftmaxCpu::build(SoftmaxInfo(*xTensor, axis));
auto kCuda = SoftmaxCuda::build(SoftmaxInfo(*xTensor, axis));
auto xTensor = Tensor::share(DataType::F32, shape);
auto outTensor = Tensor::share(DataType::F32, shape);
SoftmaxInfo info(*xTensor, axis);
auto kCpu = SoftmaxCpu::build(info);
auto kCuda = SoftmaxCuda::build(info);
ASSERT_TRUE(kCpu && kCuda);
auto res = runtime::Resources();
auto rCpu = kCpu->lower(res).routine;
Expand All @@ -28,6 +29,7 @@ TEST(kernel, SoftmaxCuda) {
std::vector<float>
data(xTensor->elementsSize(), 0),
cpuOut(outTensor->elementsSize());
std::iota(data.begin(), data.end(), 0);
gpuIn->copyFromHost(data.data(), xTensor->bytesSize());
// inference
{
Expand All @@ -49,4 +51,9 @@ TEST(kernel, SoftmaxCuda) {
}
}

TEST(kernel, SoftmaxCuda) {
test({2, 3, 2, 5, 4}, 1);
test({2, 2048, 2, 5, 4}, 1);
}

#endif