Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feat: 添加寒武纪平台where/expand/conv算子
  • Loading branch information
Chamberlain0w0 authored and YdrMaster committed Jan 31, 2024
commit e329552d84db11a4626fe8ab08777f47e8d0f5bc
6 changes: 6 additions & 0 deletions src/04kernel/src/collectors/conv.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "kernel/collectors/conv.h"
#include "../kernels/conv/cnnl_kernel.hh"
#include "../kernels/conv/cudnn_kernel.hh"

namespace refactor::kernel {
Expand All @@ -23,6 +24,11 @@ namespace refactor::kernel {
ans.emplace_back(std::move(ptr));
}
break;
case decltype(_target)::Mlu:
if (auto ptr = ConvCnnl::build(poolAttrs, x, w, b, y); ptr) {
ans.emplace_back(std::move(ptr));
}
break;
default:
UNREACHABLEX(void, "Unknown target");
}
Expand Down
8 changes: 7 additions & 1 deletion src/04kernel/src/collectors/where.cc
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
#include "kernel/collectors/where.h"
#include "../kernels/where/cnnl_kernel.hh"
#include "../kernels/where/cpu_kernel.hh"
#include "../kernels/where/where_cuda.hh"

namespace refactor::kernel {

std::vector<KernelBox>
WhereCollector::filter(TensorRefs inputs, TensorRefs) const {
WhereCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
std::vector<KernelBox> ans;
switch (_target) {
case decltype(_target)::Cpu:
Expand All @@ -18,6 +19,11 @@ namespace refactor::kernel {
ans.emplace_back(std::move(ptr));
}
break;
case decltype(_target)::Mlu:
if (auto ptr = WhereCnnl::build(inputs, outputs); ptr) {
ans.emplace_back(std::move(ptr));
}
break;
default:
UNREACHABLEX(void, "Unknown target");
}
Expand Down
6 changes: 3 additions & 3 deletions src/04kernel/src/kernels/batch_normalization/cnnl_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,10 @@ namespace refactor::kernel {
CNNL_ASSERT(cnnlSetTransposeDescriptor(d->NHWC2NCHW, 4, permuteOut));

auto handle = res.fetchOrStore<CnnlContext>()->handle;
auto xTransSize = cnnlGetTensorElementNum(d->inDescTrans) * sizeof(info.dtX);
auto xTransSize = cnnlGetTensorElementNum(d->inDescTrans) * info.dtX.size();
size_t workspaceSize;
CNNL_ASSERT(cnnlGetTransposeWorkspaceSize(handle, d->inDesc, d->NCHW2NHWC, &workspaceSize));
size_t totalWorkspaceSize = xTransSize + workspaceSize;
size_t totalWorkspaceSize = xTransSize * 2 + workspaceSize;

res.fetchOrStore<CnnlContext>();
auto routine = [d = std::move(d),
Expand All @@ -129,7 +129,7 @@ namespace refactor::kernel {

void *xTrans = workspace;
void *yTrans = xTrans + xTransSize;
void *cursor = yTrans + workspaceSize;
void *cursor = yTrans + xTransSize;

// transpose NCHW input to NHWC
CNNL_ASSERT(cnnlTranspose_v2(handle, d->NCHW2NHWC, d->inDesc, x,
Expand Down
243 changes: 243 additions & 0 deletions src/04kernel/src/kernels/conv/cnnl_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
#include "cnnl_kernel.hh"

#ifdef USE_BANG
#include "../../utilities/bang/cnnl_context.hh"
#include "../../utilities/bang/cnnl_functions.h"
#include "../expand/cnnl_kernel.hh"
#include "hardware/functions.h"
#endif

namespace refactor::kernel {
using K = ConvCnnl;

K::ConvCnnl(decltype(info) info_) noexcept
: Kernel(), info(std::move(info_)) {}

auto K::build(PoolAttributes const &poolAttributes,
Tensor const &x,
Tensor const &w,
std::optional<std::reference_wrapper<Tensor const>> b,
Tensor const &y) -> KernelBox {
static const std::unordered_set<decltype(DataType::internal)>
SET{DataType::FP16, DataType::BF16, DataType::F32, DataType::F64, DataType::I8};
#ifndef USE_BANG
return nullptr;
#endif

auto dt = x.dataType;
if (!SET.contains(dt) || w.dataType != dt || y.dataType != dt) {
return nullptr;
}

std::optional<ExpandInfoCnnl> biasExpand = std::nullopt;
if (b) {
ASSERT(b->get().shape[0] == y.shape[1], "");
std::vector<dim_t> input(y.rank(), 1);
input[1] = y.shape[1];
biasExpand.emplace(ExpandInfoCnnl(
b->get().dataType,
slice(input.data(), input.size()),
slice(y.shape.data(), y.rank())));
}

// group is not supported
if (w.rank() != 4 || poolAttributes.rank() != 2) {
return nullptr;
}
auto d = poolAttributes.dilations(),
p = poolAttributes.pads(),
s = poolAttributes.strides();
return std::make_unique<K>(decltype(info){
dt,
{
static_cast<int>(x.shape[0]),
static_cast<int>(x.shape[1]),
static_cast<int>(x.shape[2]),
static_cast<int>(x.shape[3]),
},
{
static_cast<int>(w.shape[0]),
static_cast<int>(w.shape[1]),
static_cast<int>(w.shape[2]),
static_cast<int>(w.shape[3]),
},
{
static_cast<int>(y.shape[0]),
static_cast<int>(y.shape[1]),
static_cast<int>(y.shape[2]),
static_cast<int>(y.shape[3]),
},
{d[0], d[1]},
{p[0], p[1], p[2], p[3]},
{s[0], s[1]},
std::move(biasExpand),
});
}

auto K::typeId() noexcept -> size_t {
static uint8_t ID = 1;
return reinterpret_cast<size_t>(&ID);
}

auto K::kernelTypeId() const noexcept -> size_t { return typeId(); }
auto K::description() const noexcept -> std::string_view {
return "Performing conv using CNNL";
}

#ifdef USE_BANG

auto ConvCnnl::lower(Resources &res) const -> RoutineWorkspace {
using namespace cnnl;
using namespace runtime;

// RAII for closure
struct Descriptors {
cnnlTensorDescriptor_t x, y, w;
cnnlTensorDescriptor_t xTrans, yTrans, wTrans;
cnnlTransposeDescriptor_t NCHW2NHWC, NHWC2NCHW;
cnnlConvolutionDescriptor_t conv;
cnnlConvolutionForwardAlgo_t algo;
// std::optional<ExtraPadding> extraPadding;
std::optional<Routine> biasExpand;
bool f32;

Descriptors(decltype(f32) f32_)
:// extraPadding(std::nullopt),
biasExpand(std::nullopt),
f32(f32_) {
CNNL_ASSERT(cnnlCreateTensorDescriptor(&x));
CNNL_ASSERT(cnnlCreateTensorDescriptor(&y));
CNNL_ASSERT(cnnlCreateTensorDescriptor(&w));
CNNL_ASSERT(cnnlCreateTensorDescriptor(&xTrans));
CNNL_ASSERT(cnnlCreateTensorDescriptor(&yTrans));
CNNL_ASSERT(cnnlCreateTensorDescriptor(&wTrans));
CNNL_ASSERT(cnnlCreateTransposeDescriptor(&NCHW2NHWC));
CNNL_ASSERT(cnnlCreateTransposeDescriptor(&NHWC2NCHW));
CNNL_ASSERT(cnnlCreateConvolutionDescriptor(&conv));
}
~Descriptors() noexcept(false) {
CNNL_ASSERT(cnnlDestroyTensorDescriptor(x));
CNNL_ASSERT(cnnlDestroyTensorDescriptor(y));
CNNL_ASSERT(cnnlDestroyTensorDescriptor(w));
CNNL_ASSERT(cnnlDestroyTensorDescriptor(xTrans));
CNNL_ASSERT(cnnlDestroyTensorDescriptor(yTrans));
CNNL_ASSERT(cnnlDestroyTensorDescriptor(wTrans));
CNNL_ASSERT(cnnlDestroyTransposeDescriptor(NCHW2NHWC));
CNNL_ASSERT(cnnlDestroyTransposeDescriptor(NHWC2NCHW));
CNNL_ASSERT(cnnlDestroyConvolutionDescriptor(conv));
}

Descriptors(const Descriptors &) = delete;
Descriptors(Descriptors &&) = delete;
};
auto d = std::make_shared<Descriptors>(info.dt != DataType::F64);
// d->extraPadding = ExtraPadding::build(info.dt, info.xShape, info.pad);
if (info.biasExpand) {
d->biasExpand = ExpandCnnl(*info.biasExpand).lower(res).routine;
}
int xs[]{
info.xShape[0],
info.xShape[1],
info.xShape[2] + std::abs(info.pad[0] - info.pad[2]),
info.xShape[3] + std::abs(info.pad[1] - info.pad[3]),
};

auto NHWC = [](const int shape[]) -> std::vector<int> {
return {
shape[0], shape[2], shape[3], shape[1]};
};

std::vector<int> xsNHWC = NHWC(xs);
std::vector<int> wsNHWC = NHWC(info.wShape);
std::vector<int> ysNHWC = NHWC(info.yShape);

setCnnlTensor(d->x, info.dt, slice(xs, 4));
setCnnlTensor(d->y, info.dt, slice(info.yShape, 4));
setCnnlTensor(d->w, info.dt, slice(info.wShape, 4));
CNNL_ASSERT(cnnlSetTensorDescriptor(d->xTrans, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(info.dt), 4, xsNHWC.data()));
CNNL_ASSERT(cnnlSetTensorDescriptor(d->yTrans, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(info.dt), 4, ysNHWC.data()));
CNNL_ASSERT(cnnlSetTensorDescriptor(d->wTrans, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(info.dt), 4, wsNHWC.data()));

auto xTransSize = cnnlGetTensorElementNum(d->xTrans) * info.dt.size();
auto yTransSize = cnnlGetTensorElementNum(d->yTrans) * info.dt.size();
auto wTransSize = cnnlGetTensorElementNum(d->wTrans) * info.dt.size();

int permuteIn[4] = {0, 2, 3, 1};
int permuteOut[4] = {0, 3, 1, 2};
CNNL_ASSERT(cnnlSetTransposeDescriptor(d->NCHW2NHWC, 4, permuteIn));
CNNL_ASSERT(cnnlSetTransposeDescriptor(d->NHWC2NCHW, 4, permuteOut));

size_t xWorkspaceSize, yWorkspaceSize, wWorkspaceSize, convWorkspaceSize;
auto handle = res.fetchOrStore<CnnlContext>()->handle;
CNNL_ASSERT(cnnlGetTransposeWorkspaceSize(handle, d->x, d->NCHW2NHWC, &xWorkspaceSize));
CNNL_ASSERT(cnnlGetTransposeWorkspaceSize(handle, d->w, d->NCHW2NHWC, &wWorkspaceSize));
CNNL_ASSERT(cnnlGetTransposeWorkspaceSize(handle, d->yTrans, d->NHWC2NCHW, &yWorkspaceSize));

// clang-format off
auto computation = info.dt == DataType::F64 ? DataType::F64
: info.dt == DataType::I8 ? DataType::I32
: DataType::F32;
// clang-format on
auto group = xs[1] / info.wShape[1];
CNNL_ASSERT(cnnlSetConvolutionDescriptor(d->conv, 4, info.pad, info.stride, info.dilation, group, cnnlDataTypeConvert(computation)));
CNNL_ASSERT(cnnlGetConvolutionForwardAlgorithm(
handle, d->conv, d->xTrans, d->wTrans, d->yTrans,
CNNL_CONVOLUTION_FWD_FASTEST, &d->algo));

CNNL_ASSERT(cnnlGetConvolutionForwardWorkspaceSize(
handle, d->xTrans, d->wTrans, d->yTrans, NULL,
d->conv, d->algo, &convWorkspaceSize));

// if (d->extraPadding) {
// workspaceSize = hardware::alignBytes(workspaceSize, 256);
// }

size_t workspaceSize = xTransSize + yTransSize + wTransSize + std::max({xWorkspaceSize, wWorkspaceSize, yWorkspaceSize, convWorkspaceSize});

res.fetchOrStore<CnnlContext>();
auto routine = [d, xTransSize, yTransSize, wTransSize,
xWorkspaceSize, wWorkspaceSize,
yWorkspaceSize, convWorkspaceSize](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
auto handle = res.fetchOrStore<CnnlContext>()->handle;
void const *x = inputs[0], *w = inputs[1];
void *y = outputs[0];
// if (auto f = d->extraPadding; f) {
// x = (*f)(x, reinterpret_cast<uint8_t *>(workspace) + workspaceSize);
// }
// if (auto f = d->biasExpand; f) {
// (*f)(res, workspace, inputs + 2, outputs);
// }

void *xTrans = workspace;
void *wTrans = xTrans + xTransSize;
void *yTrans = wTrans + wTransSize;
void *opWorkspace = yTrans + yTransSize;

// transpose NCHW input to NHWC
CNNL_ASSERT(cnnlTranspose_v2(handle, d->NCHW2NHWC, d->x, x,
d->xTrans, xTrans, opWorkspace, xWorkspaceSize));
CNNL_ASSERT(cnnlTranspose_v2(handle, d->NCHW2NHWC, d->w, w,
d->wTrans, wTrans, opWorkspace, wWorkspaceSize));

// build alpha/beta for double
auto a = d->f32 ? factor<fp32_t>(1) : factor<fp64_t>(1),
b = d->f32
? factor<fp32_t>(d->biasExpand ? 1 : 0)
: factor<fp64_t>(d->biasExpand ? 1 : 0);
CNNL_ASSERT(cnnlConvolutionForward(
handle,
d->conv, d->algo, &a,
d->xTrans, xTrans, d->wTrans, wTrans,
NULL, NULL, opWorkspace, convWorkspaceSize,
&b, d->yTrans, yTrans));

// transpose NHWC intermediates to NCHW
CNNL_ASSERT(cnnlTranspose_v2(handle, d->NHWC2NCHW, d->yTrans, yTrans,
d->y, y, opWorkspace, yWorkspaceSize));
};
return {std::move(routine), workspaceSize};
}

#endif

}// namespace refactor::kernel
43 changes: 43 additions & 0 deletions src/04kernel/src/kernels/conv/cnnl_kernel.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#ifndef KERNEL_CONV_CNNL_KERNEL_HH
#define KERNEL_CONV_CNNL_KERNEL_HH

#include "../../kernels/expand/cnnl_kernel.hh"
#include "kernel/attributes/pool_attributes.h"
#include "kernel/kernel.h"
#include <optional>

namespace refactor::kernel {

/// @brief Use `cnnlConvolutionForward`.
/// It only supports 4D tensors.
struct ConvCnnl final : public Kernel {
struct {
DataType dt;
int xShape[4],
wShape[4],
yShape[4],
dilation[2],
pad[4],
stride[2];
std::optional<ExpandInfoCnnl> biasExpand;
} info;

explicit ConvCnnl(decltype(info)) noexcept;

static KernelBox build(PoolAttributes const &,
Tensor const &,
Tensor const &,
std::optional<std::reference_wrapper<Tensor const>>,
Tensor const &);
static size_t typeId() noexcept;

size_t kernelTypeId() const noexcept final;
std::string_view description() const noexcept final;
#ifdef USE_BANG
RoutineWorkspace lower(Resources &) const final;
#endif
};

}// namespace refactor::kernel

#endif// KERNEL_CONV_CNNL_KERNEL_HH
Loading