Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feat: mlu跑通llama,但未得到正确结果 (MLU runs llama end-to-end, but results are not yet correct)
  • Loading branch information
Chamberlain0w0 authored and YdrMaster committed Jan 31, 2024
commit 5aa7a1e6d8edd320de298d99ad486d4a440e56ad
4 changes: 3 additions & 1 deletion src/04kernel/src/collectors/concat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ namespace refactor::kernel {
ConcatCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
SplitInfo info(axis, inputs);

auto const &b = outputs[0];

std::vector<KernelBox> ans;
switch (_target) {
case decltype(_target)::Cpu:
Expand All @@ -22,7 +24,7 @@ namespace refactor::kernel {
}
break;
case decltype(_target)::Mlu:
if (auto ptr = ConcatCnnl::build(axis, inputs, outputs[0].get()); ptr) {
if (auto ptr = ConcatCnnl::build(axis, inputs, b); ptr) {
ans.emplace_back(std::move(ptr));
}
break;
Expand Down
9 changes: 7 additions & 2 deletions src/04kernel/src/collectors/gather.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@ namespace refactor::kernel {
GatherCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
GatherInfo info(axis, inputs[0], inputs[1]);

std::vector<KernelBox> ans;
auto const &a = inputs[0];
auto const &b = inputs[1];
auto const &c = outputs[0];

std::vector<KernelBox>
ans;
switch (_target) {
case decltype(_target)::Cpu:
if (auto ptr = GatherCpu::build(info); ptr != nullptr) {
Expand All @@ -22,7 +27,7 @@ namespace refactor::kernel {
}
break;
case decltype(_target)::Mlu:
if (auto ptr = GatherCnnl::build(axis, inputs[0].get(), inputs[1].get(), outputs[0].get()); ptr != nullptr) {
if (auto ptr = GatherCnnl::build(axis, a, b, c); ptr != nullptr) {
ans.emplace_back(std::move(ptr));
}
break;
Expand Down
4 changes: 3 additions & 1 deletion src/04kernel/src/collectors/split.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ namespace refactor::kernel {
SplitCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
SplitInfo info(axis, outputs);

auto const &a = inputs[0];

std::vector<KernelBox> ans;
switch (_target) {
case decltype(_target)::Cpu:
Expand All @@ -22,7 +24,7 @@ namespace refactor::kernel {
}
break;
case decltype(_target)::Mlu:
if (auto ptr = SplitCnnl::build(axis, inputs[0].get(), outputs); ptr) {
if (auto ptr = SplitCnnl::build(axis, a, outputs); ptr) {
ans.emplace_back(std::move(ptr));
}
break;
Expand Down
2 changes: 1 addition & 1 deletion src/04kernel/src/kernels/concat/cnnl_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace refactor::kernel {
K::ConcatCnnl(SplitInfoCnnl info_) noexcept
: Kernel(), info(std::move(info_)) {}

auto K::build(int axis, TensorRefs inputs, Tensor output) noexcept -> KernelBox {
auto K::build(int axis, TensorRefs inputs, Tensor const &output) noexcept -> KernelBox {
#ifndef USE_BANG
return nullptr;
#endif
Expand Down
2 changes: 1 addition & 1 deletion src/04kernel/src/kernels/concat/cnnl_kernel.hh
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace refactor::kernel {

explicit ConcatCnnl(SplitInfoCnnl) noexcept;

static KernelBox build(int, TensorRefs, Tensor) noexcept;
static KernelBox build(int, TensorRefs, Tensor const &) noexcept;
static size_t typeId() noexcept;

size_t kernelTypeId() const noexcept final;
Expand Down
5 changes: 3 additions & 2 deletions src/04kernel/src/kernels/gather/cnnl_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@ namespace refactor::kernel {
K::GatherCnnl(decltype(info) info_) noexcept
: Kernel(), info(std::move(info_)) {}

auto K::build(int axis, Tensor input, Tensor index, Tensor output) noexcept -> KernelBox {
auto K::build(int axis, Tensor const &input, Tensor const &index, Tensor const &output) noexcept -> KernelBox {
#ifndef USE_BANG
return nullptr;
#endif

return std::make_unique<K>(decltype(info){
input.dataType,
index.dataType,
DataType::I32,
axis,
std::vector<int>(input.shape.begin(), input.shape.end()),
std::vector<int>(index.shape.begin(), index.shape.end()),
Expand Down
2 changes: 1 addition & 1 deletion src/04kernel/src/kernels/gather/cnnl_kernel.hh
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace refactor::kernel {

explicit GatherCnnl(decltype(info)) noexcept;

static KernelBox build(int, Tensor, Tensor, Tensor) noexcept;
static KernelBox build(int, Tensor const &, Tensor const &, Tensor const &) noexcept;
static size_t typeId() noexcept;

size_t kernelTypeId() const noexcept final;
Expand Down
6 changes: 4 additions & 2 deletions src/04kernel/src/kernels/reduce/cnnl_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,10 @@ namespace refactor::kernel {
for (auto axis : axes) {
dimsO[axis] = 1;
}
setCnnlTensor(d->x, dataType, slice(dimsI.data(), dimsI.size()));
setCnnlTensor(d->y, dataType, slice(dimsO.data(), dimsO.size()));
// setCnnlTensor(d->x, dataType, slice(dimsI.data(), dimsI.size()));
// setCnnlTensor(d->y, dataType, slice(dimsO.data(), dimsO.size()));
CNNL_ASSERT(cnnlSetTensorDescriptor(d->x, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(dataType), dimsI.size(), dimsI.data()));
CNNL_ASSERT(cnnlSetTensorDescriptor(d->y, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(dataType), dimsO.size(), dimsO.data()));

// clang-format off
auto reduceOp = reduceType == ReduceType::Mean ? CNNL_REDUCE_AVG
Expand Down
25 changes: 6 additions & 19 deletions src/04kernel/src/kernels/simple_binary/binary_cnnl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@ namespace refactor::kernel {
// !a.dataType.isFloat() ||
!ARTHIMETIC.contains(op) ||
// At least one of a,b should have the same shape as c
(a.shape != c.shape && b.shape != c.shape) ||
(a.shape != c.shape && b.shape != c.shape)
// Sub only supports broadcasting b
(a.shape != c.shape && op == Op::Sub)) {
// (a.shape != c.shape && op == Op::Sub)
) {
return nullptr;
}

Expand Down Expand Up @@ -122,18 +123,13 @@ namespace refactor::kernel {

auto handle = res.fetchOrStore<CnnlContext>()->handle;
size_t workspaceSize;
if (aDims != cDims) {
CNNL_ASSERT(cnnlGetBinaryWorkspaceSize(handle, d->bDesc,
d->aDesc, d->cDesc,
&workspaceSize));
} else {
CNNL_ASSERT(cnnlGetBinaryWorkspaceSize(handle, d->aDesc,
CNNL_ASSERT(cnnlGetBinaryWorkspaceSize(handle, d->aDesc,
d->bDesc, d->cDesc,
&workspaceSize));
}


res.fetchOrStore<CnnlContext>();
auto routine = [swap = aDims != cDims, d,
auto routine = [d = std::move(d),
workspaceSize, cnnlLogicOP,
op = this->opType](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
auto handle = res.fetchOrStore<CnnlContext>()->handle;
Expand All @@ -151,20 +147,11 @@ namespace refactor::kernel {
beta = d->f32
? factor<fp32_t>(0)
: factor<fp64_t>(0);

if (swap) {
CNNL_ASSERT(cnnlOpTensor(handle, d->opDesc,
&alphaB, d->bDesc, b,
&alphaA, d->aDesc, a,
workspace, workspaceSize,
&beta, d->cDesc, c));
} else {
CNNL_ASSERT(cnnlOpTensor(handle, d->opDesc,
&alphaA, d->aDesc, a,
&alphaB, d->bDesc, b,
workspace, workspaceSize,
&beta, d->cDesc, c));
}
} else if (op == SimpleBinaryType::Div) {
CNNL_ASSERT(cnnlDiv_v2(handle,
CNNL_COMPUTATION_HIGH_PRECISION,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ namespace refactor::kernel {

setCnnlTensor(d->tensor, dataType, slice(&size, 1));

auto cnnlUnaryForward = [this](cnnlHandle_t handle,
const cnnlTensorDescriptor_t x_desc,
const void *x,
const cnnlTensorDescriptor_t y_desc,
void *y) -> cnnlStatus_t {
switch (this->type) {
auto cnnlUnaryForward = [t = this->type](cnnlHandle_t handle,
const cnnlTensorDescriptor_t x_desc,
const void *x,
const cnnlTensorDescriptor_t y_desc,
void *y) -> cnnlStatus_t {
switch (t) {
case Ty::Abs:
return cnnlAbs(handle, x_desc, x, y_desc, y);
case Ty::Neg:
Expand All @@ -77,6 +77,7 @@ namespace refactor::kernel {
case Ty::Erf:
return cnnlErf_v2(handle, CNNL_COMPUTATION_HIGH_PRECISION, x_desc, x, y_desc, y);
default:
// fmt::println("{}", unaryName(t));
UNREACHABLE();
}
};
Expand Down
11 changes: 7 additions & 4 deletions src/04kernel/src/kernels/split/cnnl_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace refactor::kernel {
: dataType(dt_), axis(axis_), num(num_), inDim(std::move(in_)), outDims(std::move(out_)) {}


Info::SplitInfoCnnl(int axis, Tensor input, TensorRefs outputs)
Info::SplitInfoCnnl(int axis, Tensor const &input, TensorRefs outputs)
: SplitInfoCnnl(input.dataType, axis, outputs.size(),
std::move(std::vector<int>(input.shape.begin(), input.shape.end())),
std::move([](TensorRefs tensors) -> std::vector<std::vector<int>> {
Expand All @@ -29,7 +29,7 @@ namespace refactor::kernel {
K::SplitCnnl(SplitInfoCnnl info_) noexcept
: Kernel(), info(std::move(info_)) {}

auto K::build(int axis, Tensor input, TensorRefs outputs) noexcept -> KernelBox {
auto K::build(int axis, Tensor const &input, TensorRefs outputs) noexcept -> KernelBox {
#ifndef USE_BANG
return nullptr;
#endif
Expand Down Expand Up @@ -78,9 +78,12 @@ namespace refactor::kernel {
Descriptors(Descriptors &&) = delete;
};
auto d = std::make_shared<Descriptors>(info.num, info.dataType != DT::F64);
setCnnlTensor(d->in, info.dataType, slice(info.inDim.data(), info.inDim.size()));
// setCnnlTensor(d->in, info.dataType, slice(info.inDim.data(), info.inDim.size()));
CNNL_ASSERT(cnnlSetTensorDescriptor(d->in, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(info.dataType), info.inDim.size(), info.inDim.data()));

for (auto i = 0; i < info.outDims.size(); i++) {
setCnnlTensor(d->out[i], info.dataType, slice(info.outDims[i].data(), info.outDims[i].size()));
// setCnnlTensor(d->out[i], info.dataType, slice(info.outDims[i].data(), info.outDims[i].size()));
CNNL_ASSERT(cnnlSetTensorDescriptor(d->out[i], CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(info.dataType), info.outDims[i].size(), info.outDims[i].data()));
}

auto handle = res.fetchOrStore<CnnlContext>()->handle;
Expand Down
4 changes: 2 additions & 2 deletions src/04kernel/src/kernels/split/cnnl_kernel.hh
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ namespace refactor::kernel {
std::vector<std::vector<int>> outDims;

SplitInfoCnnl(DataType, int, int, std::vector<int>, std::vector<std::vector<int>>);
SplitInfoCnnl(int, Tensor, TensorRefs);
SplitInfoCnnl(int, Tensor const &, TensorRefs);
};

struct SplitCnnl final : public Kernel {
SplitInfoCnnl info;

explicit SplitCnnl(SplitInfoCnnl) noexcept;

static KernelBox build(int, Tensor, TensorRefs) noexcept;
static KernelBox build(int, Tensor const &, TensorRefs) noexcept;
static size_t typeId() noexcept;

size_t kernelTypeId() const noexcept final;
Expand Down
1 change: 1 addition & 0 deletions src/09python_ffi/src/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ namespace refactor::python_ffi {
// clang-format off
auto target_ = target == "cpu" ? Target::Cpu
: target == "cuda" ? Target::Nvidia
: target == "mlu" ? Target::Mlu
: UNREACHABLEX(Target, "Unknown target: {}", target);
// clang-format on
return compileOn(hardware::device::fetch(target_),
Expand Down
1 change: 1 addition & 0 deletions src/09python_ffi/src/import.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ namespace refactor::python_ffi {
// clang-format off
auto type_ = type == "cpu" ? Device::Type::Cpu
: type == "nvidia" ? Device::Type::Nvidia
: type == "mlu" ? Device::Type::Mlu
: UNREACHABLEX(Device::Type, "Unknown device type: \"{}\"", type);
// clang-format on
return device::init(type_, card, "");
Expand Down