feat: start implementing attention
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Feb 19, 2024
commit cbe7c788e2ae8cccffdac8ab70615db2f8426b68
6 changes: 6 additions & 0 deletions src/02hardware/include/hardware/devices/nvidia.h
@@ -3,6 +3,12 @@
 
 #include "../device.h"
 
+#define CUDA_ASSERT(STATUS) \
+    if (auto status = (STATUS); status != cudaSuccess) { \
+        RUNTIME_ERROR(fmt::format("cuda failed on \"" #STATUS "\" with \"{}\" ({})", \
+                                  cudaGetErrorString(status), (int) status)); \
+    }
+
 namespace refactor::hardware {
 
     class Nvidia final : public Device {
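The macro wraps any cudaError_t-returning call and raises RUNTIME_ERROR with the stringified expression and the CUDA error text. A minimal sketch of a call site, assuming this header is included; the helper function is hypothetical, not part of the commit:

#include "hardware/devices/nvidia.h"
#include <cuda_runtime.h>

// Hypothetical helper: allocate device memory or throw via RUNTIME_ERROR.
void *deviceAlloc(size_t bytes) {
    void *ptr = nullptr;
    // On failure this expands to a message like:
    //   cuda failed on "cudaMalloc(&ptr, bytes)" with "out of memory" (2)
    CUDA_ASSERT(cudaMalloc(&ptr, bytes));
    return ptr;
}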
6 changes: 0 additions & 6 deletions src/02hardware/src/devices/nvidia/device.cc
@@ -4,12 +4,6 @@
 #ifdef USE_CUDA
 #include "memory.hh"
 #include <cuda_runtime.h>
-
-#define CUDA_ASSERT(STATUS) \
-    if (auto status = (STATUS); status != cudaSuccess) { \
-        RUNTIME_ERROR(fmt::format("cuda failed on \"" #STATUS "\" with \"{}\" ({})", \
-                                  cudaGetErrorString(status), (int) status)); \
-    }
 #endif
 
 namespace refactor::hardware {
8 changes: 1 addition & 7 deletions src/02hardware/src/devices/nvidia/memory.cc
@@ -1,15 +1,9 @@
 #ifdef USE_CUDA
 
 #include "memory.hh"
-#include "common.h"
+#include "hardware/devices/nvidia.h"
 #include <cuda_runtime.h>
 
-#define CUDA_ASSERT(STATUS) \
-    if (auto status = (STATUS); status != cudaSuccess) { \
-        RUNTIME_ERROR(fmt::format("cuda failed on \"" #STATUS "\" with \"{}\" ({})", \
-                                  cudaGetErrorString(status), (int) status)); \
-    }
-
 namespace refactor::hardware {
     using M = NvidiaMemory;
 
16 changes: 16 additions & 0 deletions src/04kernel/include/kernel/attributes/attention_info.h
@@ -0,0 +1,16 @@
#ifndef KERNEL_ATTENTION_INFO_H
#define KERNEL_ATTENTION_INFO_H

#include "../tensor.h"

namespace refactor::kernel {

    struct AttentionInfo {
        DataType dataType;
        dim_t batch, nHead, nKVHead, seqLen, headDim, cacheLen;
        bool concatCache, resetCache;
    };

}// namespace refactor::kernel

#endif// KERNEL_ATTENTION_INFO_H
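To make the fields concrete, a sketch of filling the struct for a hypothetical single-token decode step; all values are illustrative, and DataType::F32 is assumed to be this codebase's fp32 tag:

// Illustrative only: batch of 1, 8 query heads over 8 KV heads,
// one new token (seqLen = 1) attending against a 256-slot KV cache.
AttentionInfo info{
    .dataType = DataType::F32,
    .batch = 1,
    .nHead = 8,
    .nKVHead = 8,
    .seqLen = 1,
    .headDim = 64,
    .cacheLen = 256,
    .concatCache = true, // append the new K/V to the cache
    .resetCache = false, // keep existing cache contents
};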
3 changes: 1 addition & 2 deletions src/04kernel/include/kernel/collectors/attention.h
@@ -6,9 +6,8 @@
 namespace refactor::kernel {
 
     struct AttentionCollector final : public InfoCollector {
-        dim_t maxSeqLen;
 
-        AttentionCollector(decltype(_target), decltype(maxSeqLen)) noexcept;
+        AttentionCollector(decltype(_target)) noexcept;
 
         std::vector<KernelBox>
         filter(TensorRefs inputs, TensorRefs outputs) const final;
55 changes: 37 additions & 18 deletions src/04kernel/src/collectors/attention.cc
@@ -1,38 +1,57 @@
 #include "kernel/collectors/attention.h"
+#include "kernel/attributes/attention_info.h"
 // #include "../kernels/attention/cpu_kernel.hh"
 #include "../kernels/attention/cuda_kernel.hh"
 
 namespace refactor::kernel {
 
     AttentionCollector::AttentionCollector(
-        decltype(_target) target,
-        decltype(maxSeqLen) maxSeqLen_) noexcept
-        : InfoCollector(target),
-          maxSeqLen(maxSeqLen_) {}
+        decltype(_target) target) noexcept
+        : InfoCollector(target) {}
 
     std::vector<KernelBox>
     AttentionCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
         auto const &query = inputs[0].get();
         auto const &key = inputs[1].get();
-        auto pastSeqLen = inputs.size() == 3 ? 0 : *inputs[2].get().data->get<int64_t>();
-        auto cacheLen = outputs.size() == 1 ? 0 : outputs[1].get().shape[2];
 
-        std::vector<KernelBox> ans;
+        AttentionInfo info{
+            .dataType = query.dataType,
+            .batch = query.shape[0],
+            .nHead = query.shape[1],
+            .nKVHead = key.shape[1],
+            .seqLen = query.shape[2],
+            .headDim = query.shape[3],
+            .cacheLen = 0,
+            .concatCache = false,
+            .resetCache = false,
+        };
+        switch (outputs.size()) {
+            case 1:
+                // no kv cache
+                ASSERT(inputs.size() == 3, "");
+                break;
+            case 3:
+                switch (inputs.size()) {
+                    case 6:
+                        info.resetCache = true;
+                    case 4:
+                        info.concatCache = true;
+                    case 3:
+                        info.cacheLen = outputs[1].get().shape[2];
+                        break;
+                    default:
+                        UNREACHABLE();
+                }
+                break;
+            default:
+                UNREACHABLE();
+        }
+
+        std::vector<KernelBox> ans;
         switch (_target) {
             case decltype(_target)::Cpu:
                 break;
             case decltype(_target)::Nvidia: {
-                decltype(AttentionCuda::info) info{
-                    .dataType = query.dataType,
-                    .batch = query.shape[0],
-                    .nHead = query.shape[1],
-                    .nKVHead = key.shape[1],
-                    .pastSeqLen = static_cast<dim_t>(pastSeqLen),
-                    .seqLen = query.shape[2],
-                    .cacheLen = cacheLen,
-                    .headDim = query.shape[3],
-                    .resetCache = false,
-                };
                 if (auto ptr = AttentionCuda::build(info); ptr) {
                     ans.emplace_back(std::move(ptr));
                 }
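The nested switches above decode the operator's signature purely from input/output arity, with intentional fallthrough: the case 6 body also runs the case 4 and case 3 bodies. My reading of the supported signatures (operand names beyond Q, K, V are assumptions, except pastSeqLen, which appears in the removed code):

// outputs == 1: attention output only; exactly Q, K, V as inputs.
// outputs == 3: also emits two cache tensors; cacheLen = outputs[1].shape[2].
//     inputs == 3: Q, K, V
//     inputs == 4: + pastSeqLen              -> concatCache = true as well
//     inputs == 6: + two more cache controls -> resetCache = true as well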
127 changes: 127 additions & 0 deletions src/04kernel/src/kernels/attention/cuda_kernel.cu
@@ -0,0 +1,127 @@
#include "../../utilities/cuda/cublaslt_utils.cuh"
#include "cuda_kernel.hh"
#include "hardware/functions.h"

namespace refactor::kernel {
using K = AttentionCuda;
using namespace cublas;

RoutineWorkspace K::lower(Resources &res) const {
auto handle = res.fetchOrStore<CublasLtContext>()->handle;

constexpr auto ROW_MAJOR = CUBLASLT_ORDER_ROW;
constexpr auto COL_MAJOR = CUBLASLT_ORDER_COL;

if (!info.cacheLen) {
if (info.nHead == info.nKVHead) {
// RAII for closure
struct Descriptors {
MatMulDescriptor mul;
MatrixDescriptor q, k, v, att;
cublasLtMatmulAlgo_t algoQK, algoAV;
size_t attSize, workspaceSizeQK, workspaceSizeAV;
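                    // q, v, att are batched row-major; k is column-major, so the
                    // first GEMM computes Q·Kᵀ per (batch, head) pair without an
                    // explicit transpose. tune() (presumably from cublaslt_utils.cuh)
                    // queries cuBLASLt once per matmul for an algorithm and its
                    // scratch size, which are cached in this struct.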

                    Descriptors(CublasLtContext const &context,
                                cublasComputeType_t compute,
                                AttentionInfo info)
                        : mul(compute, CUDA_R_32F),
                          q(MatrixLayout{
                              .dataType = dataTypeConvert(info.dataType),
                              .rows = static_cast<uint64_t>(info.seqLen),
                              .cols = static_cast<uint64_t>(info.headDim),
                              .majorStride = static_cast<int64_t>(info.headDim),
                              .order = ROW_MAJOR,
                              .batchCount = static_cast<int32_t>(info.batch * info.nHead),
                              .batchStride = static_cast<int64_t>(info.seqLen * info.headDim),
                          }),
                          k(MatrixLayout{
                              .dataType = dataTypeConvert(info.dataType),
                              .rows = static_cast<uint64_t>(info.headDim),
                              .cols = static_cast<uint64_t>(info.seqLen),
                              .majorStride = static_cast<int64_t>(info.headDim),
                              .order = COL_MAJOR,
                              .batchCount = static_cast<int32_t>(info.batch * info.nHead),
                              .batchStride = static_cast<int64_t>(info.seqLen * info.headDim),
                          }),
                          v(MatrixLayout{
                              .dataType = dataTypeConvert(info.dataType),
                              .rows = static_cast<uint64_t>(info.seqLen),
                              .cols = static_cast<uint64_t>(info.headDim),
                              .majorStride = static_cast<int64_t>(info.headDim),
                              .order = ROW_MAJOR,
                              .batchCount = static_cast<int32_t>(info.batch * info.nHead),
                              .batchStride = static_cast<int64_t>(info.seqLen * info.headDim),
                          }),
                          att(MatrixLayout{
                              .dataType = dataTypeConvert(info.dataType),
                              .rows = static_cast<uint64_t>(info.seqLen),
                              .cols = static_cast<uint64_t>(info.seqLen),
                              .majorStride = static_cast<int64_t>(info.seqLen),
                              .order = ROW_MAJOR,
                              .batchCount = static_cast<int32_t>(info.batch * info.nHead),
                              .batchStride = static_cast<int64_t>(info.seqLen * info.seqLen),
                          }),
                          attSize(info.batch * info.nHead * info.seqLen * info.seqLen * info.dataType.size()) {
                        auto [algoQK_, workspaceSizeQK_] = tune(context.handle, mul, q, k, att);
                        auto [algoAV_, workspaceSizeAV_] = tune(context.handle, mul, att, v, q);
                        algoQK = algoQK_;
                        algoAV = algoAV_;
                        workspaceSizeQK = workspaceSizeQK_;
                        workspaceSizeAV = workspaceSizeAV_;
                    }
                };

                auto const &context = *res.fetchOrStore<CublasLtContext>();
                auto d = std::make_shared<Descriptors>(context, CUBLAS_COMPUTE_32F, info);
                auto workspaceSize = d->attSize;
                workspaceSize = hardware::alignBytes(workspaceSize, 256);
                workspaceSize += d->workspaceSizeQK;
                workspaceSize = hardware::alignBytes(workspaceSize, 256);
                workspaceSize += d->workspaceSizeAV;
                workspaceSize = hardware::alignBytes(workspaceSize, 256);
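                // Workspace layout implied by the aligned bumps above,
                // each segment padded to a 256-byte boundary:
                //   [ att scores (attSize) | QK GEMM scratch | AV GEMM scratch ]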

                auto routine = [d = std::move(d), info = this->info]//
                    (Resources & res, void *workspace, void const *const *inputs, void *const *outputs) {
                        auto handle = res.fetchOrStore<CublasLtContext>()->handle;
                        auto q = inputs[0];
                        auto k = inputs[1];
                        auto v = inputs[2];
                        auto o = outputs[0];
                        auto att = workspace;
                        auto workspaceQK = reinterpret_cast<uint8_t *>(workspace) + hardware::alignBytes(d->attSize, 256);
                        auto workspaceAV = workspaceQK + hardware::alignBytes(d->workspaceSizeQK, 256);

                        float alpha = 1, beta = 0;
                        cublasLtMatmul(
                            handle, d->mul.get(),
                            &alpha,
                            q, d->q.get(),
                            k, d->k.get(),
                            &beta,
                            att, d->att.get(),
                            att, d->att.get(),
                            &d->algoQK,
                            workspaceQK, d->workspaceSizeQK,
                            cudaStreamLegacy);

                        // TODO inline mask && softmax

                        cublasLtMatmul(
                            handle, d->mul.get(),
                            &alpha,
                            att, d->att.get(),
                            v, d->v.get(),
                            &beta,
                            o, d->q.get(),
                            o, d->q.get(),
                            &d->algoAV,
                            workspaceAV, d->workspaceSizeAV,
                            cudaStreamLegacy);
                    };
                return {std::move(routine), workspaceSize};
            }
        }
        TODO("");
    }

}// namespace refactor::kernel
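For reference, the two batched GEMMs above are the two matrix products of scaled-dot-product attention,

    Attention(Q, K, V) = softmax(Q·Kᵀ / √headDim + M) · V

where the first cublasLtMatmul writes Q·Kᵀ into the att workspace and the second writes att·V into the output. The masked softmax between them is the marked TODO, and since alpha is fixed at 1, the 1/√headDim scaling is not applied yet either.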
8 changes: 2 additions & 6 deletions src/04kernel/src/kernels/attention/cuda_kernel.hh
@@ -1,17 +1,13 @@
 #ifndef KERNEL_ATTENTION_CUDA_KERNEL_HH
 #define KERNEL_ATTENTION_CUDA_KERNEL_HH
 
+#include "kernel/attributes/attention_info.h"
 #include "kernel/kernel.h"
-#include "kernel/tensor.h"
 
 namespace refactor::kernel {
 
     struct AttentionCuda final : public Kernel {
-        struct {
-            DataType dataType;
-            dim_t batch, nHead, nKVHead, pastSeqLen, seqLen, cacheLen, headDim;
-            bool resetCache;
-        } info;
+        AttentionInfo info;
 
         AttentionCuda(decltype(info)) noexcept;
 
33 changes: 0 additions & 33 deletions src/04kernel/src/utilities/cuda/cublaslt_context.cu

This file was deleted.

33 changes: 0 additions & 33 deletions src/04kernel/src/utilities/cuda/cublaslt_context.hh

This file was deleted.
