diff --git a/CMakeLists.txt b/CMakeLists.txt
index 49ddcda6..a0e853b0 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -106,4 +106,5 @@ add_subdirectory(src/06frontend)
 add_subdirectory(src/07onnx)
 add_subdirectory(src/08communication)
 add_subdirectory(src/08-01llm)
+add_subdirectory(src/08-02moe)
 add_subdirectory(src/09python_ffi)
diff --git a/src/04kernel/cuda/include/kernel/cuda/topk.cuh b/src/04kernel/cuda/include/kernel/cuda/topk.cuh
new file mode 100644
index 00000000..b06cfc00
--- /dev/null
+++ b/src/04kernel/cuda/include/kernel/cuda/topk.cuh
@@ -0,0 +1,19 @@
+#ifndef KERNEL_CUDA_TOPK_CUH
+#define KERNEL_CUDA_TOPK_CUH
+
+#include "threads_distributer.cuh"
+
+namespace refactor::kernel::cuda {
+
+    void launchTopK(
+        KernelLaunchParameters const &params,
+        float const *data, float *dstVal, unsigned int *dstIdx,
+        unsigned int topk,
+        unsigned int stride_axis,
+        unsigned int stride_in_pre,
+        unsigned int stride_out_pre,
+        unsigned int size_axis);
+
+}// namespace refactor::kernel::cuda
+
+#endif// KERNEL_CUDA_TOPK_CUH
diff --git a/src/04kernel/cuda/src/topk.cu b/src/04kernel/cuda/src/topk.cu
new file mode 100644
index 00000000..6b247ead
--- /dev/null
+++ b/src/04kernel/cuda/src/topk.cu
@@ -0,0 +1,103 @@
+#include "kernel/cuda/topk.cuh"
+#include "macro.cuh"
+#include <thrust/execution_policy.h>
+#include <thrust/pair.h>
+#include <thrust/sort.h>
+
+namespace refactor::kernel::cuda {
+
+    using PairType = thrust::pair<float, uint32_t>;
+
+    struct ComparePair {
+        __host__ __device__
+        bool operator()(const PairType &a, const PairType &b) const {
+            return a.first > b.first;
+        }
+    };
+
+    /*
+    __device__
+    void process_element(unsigned int n, float *__restrict__ dstVal,
+                         uint32_t *__restrict__ dstIdx,
+                         PairType *list,
+                         uint32_t stride_axis,
+                         uint32_t init_offset) {
+        for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
+                  step = blockDim.x * gridDim.x;
+             tid < n;
+             tid += step) {
+            uint32_t offset = init_offset + stride_axis * tid;
+            dstVal[offset] = list[tid].first;
+            dstIdx[offset] = list[tid].second;
+        }
+    }
+    */
+
+    __global__ static void TopKKernel(
+        unsigned long long n,
+        float const *__restrict__ data,
+        float *__restrict__ dstVal,
+        uint32_t *__restrict__ dstIdx,
+        uint32_t topk,
+        uint32_t stride_axis,
+        uint32_t stride_in_pre,
+        uint32_t stride_out_pre,
+        unsigned int size) {
+        for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
+                  step = blockDim.x * gridDim.x;
+             tid < n;
+             tid += step) {
+            // Per-thread scratch holding the (value, index) pairs of one slice.
+            PairType *list = new PairType[size];
+
+            for (uint32_t i = 0; i < size; i++) {
+                uint32_t srcIdx = tid / stride_axis * stride_in_pre + tid % stride_axis + i * stride_axis;
+                list[i] = PairType(data[srcIdx], i);
+            }
+            // thrust has no partial_sort; a possible optimization is to split
+            // the slice into size/topk groups and take one maximum per group.
+            thrust::sort(thrust::device, list, list + size, ComparePair());
+
+            uint32_t init_offset = tid / stride_axis * stride_out_pre + tid % stride_axis;
+            for (uint32_t i = 0; i < topk; i++) {
+                uint32_t offset = init_offset + stride_axis * i;
+                dstVal[offset] = list[i].first;
+                dstIdx[offset] = list[i].second;
+            }
+
+            delete[] list;
+        }
+    }
+
+    void launchTopK(
+        KernelLaunchParameters const &params,
+        float const *data, float *dstVal, uint32_t *dstIdx,
+        uint32_t topk,
+        uint32_t stride_axis,
+        uint32_t stride_in_pre,
+        uint32_t stride_out_pre,
+        unsigned int size_axis) {
+
+        TopKKernel<<<
+            params.gridSize,
+            params.blockSize,
+            0,
+            reinterpret_cast<cudaStream_t>(params.stream)>>>(
+            params.n,
+            data,
+            dstVal,
+            dstIdx,
+            topk,
+            stride_axis,
+            stride_in_pre,
+            stride_out_pre,
+            size_axis);
+    }
+
+}// namespace refactor::kernel::cuda
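`TopKKernel` flattens every element off the reduced axis into one thread id; the index arithmetic is easier to check on the host. A minimal standalone sketch (plain C++, variable names are mine, not part of the patch) reproducing the mapping for the 3x4, axis = 1 case exercised by the tests below:

```cpp
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
    // 3x4 input, axis = 1: stride_axis = 1, stride_in_pre = 4,
    // stride_out_pre = 4 / 4 * 3 = 3, size_axis = 4, topk = 3.
    unsigned topk = 3, stride_axis = 1, in_pre = 4, out_pre = 3, size_axis = 4, n = 3;
    std::vector<float> data(12);
    for (unsigned i = 0; i < 12; ++i) data[i] = float(i);
    std::vector<float> dstVal(n * topk);
    std::vector<unsigned> dstIdx(n * topk);
    for (unsigned tid = 0; tid < n; ++tid) {// one CUDA thread per slice
        std::vector<std::pair<float, unsigned>> list(size_axis);
        for (unsigned i = 0; i < size_axis; ++i)
            list[i] = {data[tid / stride_axis * in_pre + tid % stride_axis + i * stride_axis], i};
        // Same descending order as ComparePair in topk.cu.
        std::sort(list.begin(), list.end(),
                  [](auto &a, auto &b) { return a.first > b.first; });
        auto base = tid / stride_axis * out_pre + tid % stride_axis;
        for (unsigned i = 0; i < topk; ++i) {
            dstVal[base + stride_axis * i] = list[i].first;
            dstIdx[base + stride_axis * i] = list[i].second;
        }
    }
    for (auto v : dstVal) std::printf("%g ", v);// 3 2 1 7 6 5 11 10 9
}
```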
diff --git a/src/04kernel/include/kernel/attributes/moe_info.h b/src/04kernel/include/kernel/attributes/moe_info.h
new file mode 100644
index 00000000..3e46b505
--- /dev/null
+++ b/src/04kernel/include/kernel/attributes/moe_info.h
@@ -0,0 +1,24 @@
+#ifndef KERNEL_MOE_INFO_H
+#define KERNEL_MOE_INFO_H
+
+#include "../tensor.h"
+
+namespace refactor::kernel {
+
+    struct AssignPosInfo {
+        int64_t top, expert_num;
+        int64_t elementSize;
+
+        AssignPosInfo(int64_t top, int64_t expert_num, Tensor const &gate);
+    };
+
+    struct ReorderInfo {
+        bool scatter;
+        int64_t top;
+        int64_t blockNum, blockSize;
+        ReorderInfo(bool scatter, int64_t top, TensorRefs inputs);
+    };
+
+}// namespace refactor::kernel
+
+#endif// KERNEL_MOE_INFO_H
diff --git a/src/04kernel/include/kernel/attributes/topk_info.h b/src/04kernel/include/kernel/attributes/topk_info.h
new file mode 100644
index 00000000..5cfc5ee6
--- /dev/null
+++ b/src/04kernel/include/kernel/attributes/topk_info.h
@@ -0,0 +1,24 @@
+#ifndef KERNEL_TOPK_INFO_H
+#define KERNEL_TOPK_INFO_H
+
+#include "../tensor.h"
+
+namespace refactor::kernel {
+
+    struct TopKInfo {
+        struct Stride {
+            dim_t axis, in_pre, out_pre;
+        };
+        struct Size {
+            dim_t axis, except_axis;
+        };
+        uint32_t topk;
+        Stride stride;
+        Size size;
+
+        TopKInfo(uint32_t topk, uint32_t axis, Tensor const &input);
+    };
+
+}// namespace refactor::kernel
+
+#endif// KERNEL_TOPK_INFO_H
diff --git a/src/04kernel/include/kernel/collectors/moe.h b/src/04kernel/include/kernel/collectors/moe.h
new file mode 100644
index 00000000..258450dc
--- /dev/null
+++ b/src/04kernel/include/kernel/collectors/moe.h
@@ -0,0 +1,29 @@
+#ifndef KERNEL_MOE_H
+#define KERNEL_MOE_H
+
+#include "../collector.h"
+
+namespace refactor::kernel {
+
+    struct AssignPosCollector final : public InfoCollector {
+        uint32_t topk, numExperts;
+
+        constexpr AssignPosCollector(decltype(_target) target, uint32_t topk, uint32_t numExperts) noexcept
+            : InfoCollector(target), topk(topk), numExperts(numExperts) {}
+
+        std::vector<KernelBox>
+        filter(TensorRefs inputs, TensorRefs outputs) const final;
+    };
+
+    struct ReorderCollector final : public InfoCollector {
+        bool scatter;
+        int64_t topk;
+
+        constexpr ReorderCollector(decltype(_target) target, bool scatter, int64_t topk) noexcept
+            : InfoCollector(target), scatter(scatter), topk(topk) {}
+
+        std::vector<KernelBox>
+        filter(TensorRefs inputs, TensorRefs outputs) const final;
+    };
+
+}// namespace refactor::kernel
+
+#endif// KERNEL_MOE_H
diff --git a/src/04kernel/include/kernel/collectors/topk.h b/src/04kernel/include/kernel/collectors/topk.h
new file mode 100644
index 00000000..c4d8490f
--- /dev/null
+++ b/src/04kernel/include/kernel/collectors/topk.h
@@ -0,0 +1,20 @@
+#ifndef KERNEL_TOPK_H
+#define KERNEL_TOPK_H
+
+#include "../collector.h"
+
+namespace refactor::kernel {
+
+    struct TopKCollector final : public InfoCollector {
+        uint32_t topk, axis;
+
+        constexpr TopKCollector(decltype(_target) target, uint32_t topk, uint32_t axis_) noexcept
+            : InfoCollector(target), topk(topk), axis(axis_) {}
+
+        std::vector<KernelBox>
+        filter(TensorRefs inputs, TensorRefs outputs) const final;
+    };
+
+}// namespace refactor::kernel
+
+#endif// KERNEL_TOPK_H
diff --git a/src/04kernel/src/attributes/moe_info.cc b/src/04kernel/src/attributes/moe_info.cc
new file mode 100644
index 00000000..829c54ae
--- /dev/null
+++ b/src/04kernel/src/attributes/moe_info.cc
@@ -0,0 +1,13 @@
+#include "kernel/attributes/moe_info.h"
+
+namespace refactor::kernel {
+
+    AssignPosInfo::AssignPosInfo(int64_t top, int64_t expert_num, Tensor const &gate)
+        : top(top), expert_num(expert_num), elementSize(gate.elementsSize()) {}
+
+    ReorderInfo::ReorderInfo(bool scatter, int64_t top, TensorRefs inputs)
+        : scatter(scatter), top(top),
+          blockNum(inputs[1].get().elementsSize()),
+          blockSize(inputs[0].get().strides()[0]) {}
+
+}// namespace refactor::kernel
diff --git a/src/04kernel/src/attributes/topk_info.cc b/src/04kernel/src/attributes/topk_info.cc
new file mode 100644
index 00000000..532f385d
--- /dev/null
+++ b/src/04kernel/src/attributes/topk_info.cc
@@ -0,0 +1,16 @@
+#include "kernel/attributes/topk_info.h"
+
+namespace refactor::kernel {
+
+    TopKInfo::TopKInfo(uint32_t topk, uint32_t axis, Tensor const &input) {
+        this->topk = topk;
+        // Stride of the dimension just outside the reduced axis,
+        // or 0 when the reduced axis is outermost.
+        auto tmpStride = axis == 0 ? 0 : input.strides()[axis - 1];
+        this->stride = {input.strides()[axis],
+                        tmpStride,
+                        tmpStride / input.shape[axis] * topk};
+        this->size = {input.shape[axis],
+                      input.elementsSize() / input.shape[axis]};
+    }
+
+}// namespace refactor::kernel
diff --git a/src/04kernel/src/collectors/moe.cc b/src/04kernel/src/collectors/moe.cc
new file mode 100644
index 00000000..4906499e
--- /dev/null
+++ b/src/04kernel/src/collectors/moe.cc
@@ -0,0 +1,51 @@
+#include "kernel/collectors/moe.h"
+#include "../kernels/moe/cpu_kernel.hh"
+#include "kernel/attributes/moe_info.h"
+
+namespace refactor::kernel {
+
+    std::vector<KernelBox>
+    AssignPosCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
+        AssignPosInfo info(topk, numExperts, inputs[0]);
+        std::vector<KernelBox> ans;
+        switch (_target) {
+            case decltype(_target)::Cpu:
+                if (auto ptr = AssignPosCpu::build(info); ptr) {
+                    ans.emplace_back(std::move(ptr));
+                }
+                break;
+            // TODO: fall back to the CPU implementation for now.
+            case decltype(_target)::Nvidia:
+                if (auto ptr = AssignPosCpu::build(info); ptr) {
+                    ans.emplace_back(std::move(ptr));
+                }
+                break;
+            default:
+                UNREACHABLEX(void, "Unknown target");
+        }
+        return ans;
+    }
+
+    std::vector<KernelBox>
+    ReorderCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
+        ReorderInfo info(scatter, topk, inputs);
+        std::vector<KernelBox> ans;
+        switch (_target) {
+            case decltype(_target)::Cpu:
+                if (auto ptr = ReorderCpu::build(info); ptr) {
+                    ans.emplace_back(std::move(ptr));
+                }
+                break;
+            // TODO: fall back to the CPU implementation for now.
+            case decltype(_target)::Nvidia:
+                if (auto ptr = ReorderCpu::build(info); ptr) {
+                    ans.emplace_back(std::move(ptr));
+                }
+                break;
+            default:
+                UNREACHABLEX(void, "Unknown target");
+        }
+        return ans;
+    }
+
+}// namespace refactor::kernel
diff --git a/src/04kernel/src/collectors/topk.cc b/src/04kernel/src/collectors/topk.cc
new file mode 100644
index 00000000..91a97427
--- /dev/null
+++ b/src/04kernel/src/collectors/topk.cc
@@ -0,0 +1,30 @@
+#include "kernel/collectors/topk.h"
+#include "../kernels/topk/cpu_kernel.hh"
+#include "kernel/attributes/topk_info.h"
+//#include "../kernels/topk/cuda_kernel.hh"
+
+namespace refactor::kernel {
+
+    std::vector<KernelBox>
+    TopKCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
+        TopKInfo info(topk, axis, inputs[0]);
+        std::vector<KernelBox> ans;
+        switch (_target) {
+            case decltype(_target)::Cpu:
+                if (auto ptr = TopKCpu::build(info); ptr) {
+                    ans.emplace_back(std::move(ptr));
+                }
+                break;
+            // TODO: fall back to the CPU implementation for now.
+            case decltype(_target)::Nvidia:
+                if (auto ptr = TopKCpu::build(info); ptr) {
+                    ans.emplace_back(std::move(ptr));
+                }
+                break;
+            default:
+                UNREACHABLEX(void, "Unknown target");
+        }
+        return ans;
+    }
+
+}// namespace refactor::kernel
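For intuition about the strides `TopKInfo` precomputes: with a row-major `{2, 4, 2}` input reduced along `axis = 1` with `topk = 3` (the `TopKCpu1` case further down), element strides are `{8, 2, 1}`. A quick standalone sketch of the same arithmetic (assumed shapes, not patch code):

```cpp
#include <cassert>

int main() {
    unsigned shape[]{2, 4, 2}, strides[]{8, 2, 1}, axis = 1, topk = 3;
    unsigned stride_axis = strides[axis];                // 2: step between neighbours on the axis
    unsigned in_pre = axis == 0 ? 0 : strides[axis - 1]; // 8: one outer slice of the input
    unsigned out_pre = in_pre / shape[axis] * topk;      // 6: one outer slice of the {2, 3, 2} output
    unsigned size_axis = shape[axis];                    // 4 candidates per slice
    unsigned except_axis = 16 / size_axis;               // 4 independent slices
    assert(stride_axis == 2 && in_pre == 8 && out_pre == 6 && except_axis == 4);
}
```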
"cpu_kernel.hh" +#include +#include + +namespace refactor::kernel { + + AssignPosCpu::AssignPosCpu(AssignPosInfo info) noexcept + : Kernel(), info(std::move(info)) {} + + auto AssignPosCpu::build(AssignPosInfo info) noexcept -> KernelBox { + return std::make_unique(std::move(info)); + } + auto AssignPosCpu::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + + auto AssignPosCpu::kernelTypeId() const noexcept -> size_t { + return typeId(); + } + auto AssignPosCpu::description() const noexcept -> std::string_view { + return "Performing AssignPos operation on generic cpu"; + } + + auto AssignPosCpu::lower(Resources &) const noexcept -> RoutineWorkspace { + using namespace runtime; + return [info = this->info](Resources &, void *workspace, void const *const *inputs, void *const *outputs) { + auto gate = reinterpret_cast(inputs[0]); + + auto expert_cnt = reinterpret_cast(outputs[0]);//T + auto pos = reinterpret_cast(outputs[1]); + std::memset(expert_cnt, 0, info.expert_num); + for (size_t i = 0; i < info.elementSize; i ++){ + ASSERT (gate[i] >= 0 && gate[i] < info.expert_num, "gate exceeds expert idx scope!"); + expert_cnt[gate[i]] ++; + } + std::vector expert_accumlate; + expert_accumlate.assign(info.expert_num, 0); + for (size_t i=0; i KernelBox { + return std::make_unique(std::move(info)); + } + auto ReorderCpu::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + + auto ReorderCpu::kernelTypeId() const noexcept -> size_t { + return typeId(); + } + auto ReorderCpu::description() const noexcept -> std::string_view { + return "Performing scatter operation on generic cpu"; + } + + auto ReorderCpu::lower(Resources &) const noexcept -> RoutineWorkspace { + using namespace runtime; + return [info = this->info](Resources &, void *workspace, void const *const *inputs, void *const *outputs) { + auto input = reinterpret_cast(inputs[0]); + auto pos = reinterpret_cast(inputs[1]); + auto dstVal = reinterpret_cast(outputs[0]);//T + + for(size_t i = 0; i +#include + +namespace refactor::kernel { + using K = TopKCpu; + + K::TopKCpu(TopKInfo info) noexcept + : Kernel(), info(std::move(info)) {} + + auto K::build(TopKInfo info) noexcept -> KernelBox { + return std::make_unique(std::move(info)); + } + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { + return typeId(); + } + auto K::description() const noexcept -> std::string_view { + return "Performing topk operation on generic cpu"; + } + + auto K::lower(Resources &) const noexcept -> RoutineWorkspace { + using namespace runtime; + return [info = this->info](Resources &, void *workspace, void const *const *inputs, void *const *outputs) { + auto src = reinterpret_cast(inputs[0]); + + auto dstVal = reinterpret_cast(outputs[0]);//T + auto dstIndex = reinterpret_cast(outputs[1]); + + + size_t M = info.size.except_axis; + size_t N = info.size.axis; + + for(size_t m = 0; m < M; m ++){ + using PairType = std::pair; + std::vector list; + for(size_t n = 0; n < N; n++){ + auto srcIdx = m /info.stride.axis * info.stride.in_pre + m % info.stride.axis + n * info.stride.axis; + list.push_back({src[srcIdx],n}); + } + //list.sort([](const PairType &a, const PairType &b)->bool{return a.first > b.first;}); + std::partial_sort(list.begin(), \ + list.begin() + info.topk, \ + list.end(), \ + [](const PairType &a, const PairType &b)->bool{return a.first > b.first;}); + + size_t offset = m 
diff --git a/src/04kernel/src/kernels/topk/cpu_kernel.cc b/src/04kernel/src/kernels/topk/cpu_kernel.cc
new file mode 100644
--- /dev/null
+++ b/src/04kernel/src/kernels/topk/cpu_kernel.cc
@@ -0,0 +1,62 @@
+#include "cpu_kernel.hh"
+#include <algorithm>
+#include <utility>
+#include <vector>
+
+namespace refactor::kernel {
+    using K = TopKCpu;
+
+    K::TopKCpu(TopKInfo info) noexcept
+        : Kernel(), info(std::move(info)) {}
+
+    auto K::build(TopKInfo info) noexcept -> KernelBox {
+        return std::make_unique<K>(std::move(info));
+    }
+    auto K::typeId() noexcept -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast<size_t>(&ID);
+    }
+
+    auto K::kernelTypeId() const noexcept -> size_t {
+        return typeId();
+    }
+    auto K::description() const noexcept -> std::string_view {
+        return "Performing topk operation on generic cpu";
+    }
+
+    auto K::lower(Resources &) const noexcept -> RoutineWorkspace {
+        using namespace runtime;
+        return [info = this->info](Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
+            auto src = reinterpret_cast<float const *>(inputs[0]);
+            auto dstVal = reinterpret_cast<float *>(outputs[0]);
+            auto dstIndex = reinterpret_cast<uint32_t *>(outputs[1]);
+
+            size_t M = info.size.except_axis;
+            size_t N = info.size.axis;
+
+            for (size_t m = 0; m < M; m++) {
+                using PairType = std::pair<float, uint32_t>;
+                std::vector<PairType> list;
+                for (size_t n = 0; n < N; n++) {
+                    auto srcIdx = m / info.stride.axis * info.stride.in_pre + m % info.stride.axis + n * info.stride.axis;
+                    list.push_back({src[srcIdx], (uint32_t) n});
+                }
+                // Only the first topk entries need to be ordered.
+                std::partial_sort(list.begin(),
+                                  list.begin() + info.topk,
+                                  list.end(),
+                                  [](const PairType &a, const PairType &b) -> bool { return a.first > b.first; });
+
+                size_t offset = m / info.stride.axis * info.stride.out_pre + m % info.stride.axis;
+                std::for_each_n(list.begin(), (uint32_t) info.topk,
+                                [&](auto &elem) {
+                                    dstVal[offset] = elem.first;
+                                    dstIndex[offset] = elem.second;
+                                    offset += info.stride.axis;
+                                });
+            }
+        };
+    }
+}// namespace refactor::kernel
diff --git a/src/04kernel/src/kernels/topk/cpu_kernel.hh b/src/04kernel/src/kernels/topk/cpu_kernel.hh
new file mode 100644
index 00000000..75b2a4ce
--- /dev/null
+++ b/src/04kernel/src/kernels/topk/cpu_kernel.hh
@@ -0,0 +1,23 @@
+#ifndef KERNEL_TOPK_CPU_KERNEL_HH
+#define KERNEL_TOPK_CPU_KERNEL_HH
+
+#include "kernel/attributes/topk_info.h"
+#include "kernel/kernel.h"
+
+namespace refactor::kernel {
+
+    struct TopKCpu final : public Kernel {
+        TopKInfo info;
+        explicit TopKCpu(TopKInfo info) noexcept;
+
+        static KernelBox build(TopKInfo info) noexcept;
+        static size_t typeId() noexcept;
+
+        size_t kernelTypeId() const noexcept final;
+        std::string_view description() const noexcept final;
+        RoutineWorkspace lower(Resources &) const noexcept final;
+    };
+
+}// namespace refactor::kernel
+
+#endif// KERNEL_TOPK_CPU_KERNEL_HH
diff --git a/src/04kernel/src/kernels/topk/cuda_kernel.cc b/src/04kernel/src/kernels/topk/cuda_kernel.cc
new file mode 100644
index 00000000..acfa4733
--- /dev/null
+++ b/src/04kernel/src/kernels/topk/cuda_kernel.cc
@@ -0,0 +1,57 @@
+#include "cuda_kernel.hh"
+
+#ifdef USE_CUDA
+#include "kernel/cuda/threads_distributer.cuh"
+#include "kernel/cuda/topk.cuh"
+#endif
+
+namespace refactor::kernel {
+    using K = TopKCuda;
+
+    K::TopKCuda(TopKInfo info_) noexcept
+        : Kernel(), info(std::move(info_)) {}
+
+    auto K::build(TopKInfo info) noexcept -> KernelBox {
+#ifndef USE_CUDA
+        return nullptr;
+#endif
+
+        return std::make_unique<K>(std::move(info));
+    }
+    auto K::typeId() noexcept -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast<size_t>(&ID);
+    }
+
+    auto K::kernelTypeId() const noexcept -> size_t {
+        return typeId();
+    }
+    auto K::description() const noexcept -> std::string_view {
+        return "Performing topk operation using CUDA";
+    }
+
+#ifdef USE_CUDA
+    auto K::lower(Resources &) const noexcept -> RoutineWorkspace {
+        return [info = this->info, params = cuda::ThreadsDistributer()(info.size.except_axis)]
+            (Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
+            cuda::launchTopK(
+                params,
+                reinterpret_cast<float const *>(inputs[0]),
+                reinterpret_cast<float *>(outputs[0]),
+                reinterpret_cast<unsigned int *>(outputs[1]),
+                info.topk,
+                info.stride.axis,
+                info.stride.in_pre,
+                info.stride.out_pre,
+                info.size.axis);
+        };
+    }
+#endif
+}// namespace refactor::kernel
diff --git a/src/04kernel/src/kernels/topk/cuda_kernel.hh b/src/04kernel/src/kernels/topk/cuda_kernel.hh
new file mode 100644
index 00000000..069bbd44
--- /dev/null
+++ b/src/04kernel/src/kernels/topk/cuda_kernel.hh
@@ -0,0 +1,26 @@
+#ifndef KERNEL_TOPK_CUDA_KERNEL_HH
+#define KERNEL_TOPK_CUDA_KERNEL_HH
+
+#include "kernel/attributes/topk_info.h"
+#include "kernel/kernel.h"
+
+namespace refactor::kernel {
+
+    struct TopKCuda final : public Kernel {
+        TopKInfo info;
+
+        explicit TopKCuda(TopKInfo) noexcept;
+
+        static KernelBox build(TopKInfo) noexcept;
+        static size_t typeId() noexcept;
+
+        size_t kernelTypeId() const noexcept final;
+        std::string_view description() const noexcept final;
+#ifdef USE_CUDA
+        RoutineWorkspace lower(Resources &) const noexcept final;
+#endif
+    };
+
+}// namespace refactor::kernel
+
+#endif// KERNEL_TOPK_CUDA_KERNEL_HH
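The CPU kernel keeps (value, index) pairs so `std::partial_sort` orders only the first `topk` entries, unlike the CUDA path, which full-sorts every slice (see the TODO in `topk.cu`). The selection step in isolation, as a standalone sketch:

```cpp
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
    std::vector<std::pair<float, unsigned>> list{{0.f, 0}, {3.f, 1}, {1.f, 2}, {2.f, 3}};
    unsigned topk = 2;
    // After this call only list[0..topk) is sorted descending; the tail order is unspecified.
    std::partial_sort(list.begin(), list.begin() + topk, list.end(),
                      [](auto const &a, auto const &b) { return a.first > b.first; });
    for (unsigned i = 0; i < topk; ++i)
        std::printf("%g@%u ", list[i].first, list[i].second);// 3@1 2@3
}
```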
diff --git a/src/04kernel/test/kernels/moe/test_cpu.cpp b/src/04kernel/test/kernels/moe/test_cpu.cpp
new file mode 100644
index 00000000..2574d8b7
--- /dev/null
+++ b/src/04kernel/test/kernels/moe/test_cpu.cpp
@@ -0,0 +1,75 @@
+#include "../../../src/kernels/moe/cpu_kernel.hh"
+#include <algorithm>
+#include <functional>
+#include <gtest/gtest.h>
+#include <numeric>
+
+using namespace refactor;
+using namespace kernel;
+
+TEST(kernel, AssignPosCpu) {
+    // build routine
+    auto gate = Tensor::share(DataType::I64, Shape{8, 2});
+    auto expert_cnt = Tensor::share(DataType::I64, Shape{4});
+    auto pos = Tensor::share(DataType::I64, Shape{16});
+
+    auto kernel = AssignPosCpu::build(AssignPosInfo(2, 4, *gate));
+    ASSERT_TRUE(kernel);
+    auto res = runtime::Resources();
+    auto routine = kernel->lower(res).routine;
+    // put input data
+    std::vector<int64_t> ins = {3, 2, 0, 1, 2, 1, 1, 3, 2, 0, 1, 3, 1, 0, 1, 2};
+    std::vector<int64_t> out0(expert_cnt->elementsSize());
+    std::vector<int64_t> out1(pos->elementsSize());
+
+    // inference
+    void const *inputs[]{ins.data()};
+    void *outputs[]{out0.data(), out1.data()};
+    routine(res, nullptr, inputs, outputs);
+
+    // check
+    std::vector<int64_t> expectExpertCnt = {3, 6, 4, 3};
+    std::vector<int64_t> expectPos = {13, 9, 2, 14, 12, 10, 6, 5, 3, 15, 8, 4, 1, 11, 7, 0};
+    for (size_t i = 0; i < expectExpertCnt.size(); ++i) {
+        EXPECT_EQ(expectExpertCnt[i], out0[i]);
+    }
+    for (size_t i = 0; i < expectPos.size(); ++i) {
+        EXPECT_EQ(expectPos[i], out1[i]);
+    }
+}
+
+TEST(kernel, ReorderCpu) {
+    // build routine
+    int64_t top = 2;
+    auto input = Tensor::share(DataType::F32, Shape{8, 2});
+    auto pos = Tensor::share(DataType::I64, Shape{16});
+    std::vector<Arc<Tensor>> inputTensors{input, pos};
+    TensorRefs inputs_;
+    inputs_.reserve(inputTensors.size());
+    std::transform(inputTensors.begin(), inputTensors.end(),
+                   std::back_inserter(inputs_),
+                   [](auto const &it) { return std::cref(*it); });
+
+    auto kernel = ReorderCpu::build(ReorderInfo(true, top, inputs_));
+    ASSERT_TRUE(kernel);
+    auto res = runtime::Resources();
+    auto routine = kernel->lower(res).routine;
+    // put input data
+    std::vector<float> ins0(input->elementsSize());
+    std::iota(ins0.begin(), ins0.end(), 0);
+    std::vector<int64_t> ins1 = {13, 9, 2, 14, 12, 10, 6, 5, 3, 15, 8, 4, 1, 11, 7, 0};
+    std::vector<float> out(input->elementsSize() * top);
+
+    // inference: smoke test, only exercises the routine without numeric checks
+    void const *inputs[]{ins0.data(), ins1.data()};
+    void *outputs[]{out.data()};
+    routine(res, nullptr, inputs, outputs);
+}
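`ReorderCpu` moves whole token rows of `blockSize` floats. Under the scatter reading assumed in the kernel above (slot `i` of the expert-ordered buffer reads token `pos[i] / top`; gather inverts the copy), a tiny sketch — assumed semantics, not an authoritative restatement of the original:

```cpp
#include <cstdio>
#include <vector>

int main() {
    size_t top = 2, blockSize = 2;                 // two routing choices, two floats per token
    std::vector<float> in{0, 1, 2, 3, 4, 5, 6, 7}; // 4 token rows
    std::vector<long> pos{6, 1, 4, 3, 0, 7, 2, 5}; // expert-ordered slots -> token*top indices
    std::vector<float> out(pos.size() * blockSize);
    for (size_t i = 0; i < pos.size(); ++i)        // scatter: slot i <- token pos[i] / top
        for (size_t j = 0; j < blockSize; ++j)
            out[i * blockSize + j] = in[size_t(pos[i]) / top * blockSize + j];
    for (auto v : out) std::printf("%g ", v);      // rows 3, 0, 2, 1, 0, 3, 1, 2
}
```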
diff --git a/src/04kernel/test/kernels/topk/test_cpu.cpp b/src/04kernel/test/kernels/topk/test_cpu.cpp
new file mode 100644
--- /dev/null
+++ b/src/04kernel/test/kernels/topk/test_cpu.cpp
@@ -0,0 +1,191 @@
+#include "../../../src/kernels/topk/cpu_kernel.hh"
+#include <gtest/gtest.h>
+#include <numeric>
+
+using namespace refactor;
+using namespace kernel;
+
+TEST(kernel, TopKCpu) {
+    // build routine
+    auto inputTensor = Tensor::share(DataType::F32, Shape{3, 4});
+    auto outputTensor0 = Tensor::share(DataType::F32, Shape{3, 3});
+    auto outputTensor1 = Tensor::share(DataType::U32, Shape{3, 3});
+
+    auto kernel = TopKCpu::build(TopKInfo(3, 1, *inputTensor));
+    ASSERT_TRUE(kernel);
+    auto res = runtime::Resources();
+    auto routine = kernel->lower(res).routine;
+    // put input data
+    std::vector<float> ins(inputTensor->elementsSize());
+    std::vector<float> out0(outputTensor0->elementsSize());
+    std::vector<uint32_t> out1(outputTensor1->elementsSize());
+
+    std::iota(ins.begin(), ins.end(), 0);
+    // inference
+    void const *inputs[]{ins.data()};
+    void *outputs[]{out0.data(), out1.data()};
+    routine(res, nullptr, inputs, outputs);
+
+    // check
+    std::vector<float> expectVal = {3, 2, 1, 7, 6, 5, 11, 10, 9};
+    std::vector<uint32_t> expectIdx = {3, 2, 1, 3, 2, 1, 3, 2, 1};
+    for (size_t i = 0; i < expectVal.size(); ++i) {
+        EXPECT_EQ(expectVal[i], out0[i]);
+        EXPECT_EQ(expectIdx[i], out1[i]);
+    }
+}
+
+TEST(kernel, TopKCpu1) {
+    // build routine
+    auto inputTensor = Tensor::share(DataType::F32, Shape{2, 4, 2});
+    auto outputTensor0 = Tensor::share(DataType::F32, Shape{2, 3, 2});
+    auto outputTensor1 = Tensor::share(DataType::U32, Shape{2, 3, 2});
+
+    auto kernel = TopKCpu::build(TopKInfo(3, 1, *inputTensor));
+    ASSERT_TRUE(kernel);
+    auto res = runtime::Resources();
+    auto routine = kernel->lower(res).routine;
+    // put input data
+    std::vector<float> ins(inputTensor->elementsSize());
+    std::vector<float> out0(outputTensor0->elementsSize());
+    std::vector<uint32_t> out1(outputTensor1->elementsSize());
+
+    std::iota(ins.begin(), ins.end(), 0);
+    // inference
+    void const *inputs[]{ins.data()};
+    void *outputs[]{out0.data(), out1.data()};
+    routine(res, nullptr, inputs, outputs);
+
+    // check
+    std::vector<float> expectVal = {6, 7, 4, 5, 2, 3, 14, 15, 12, 13, 10, 11};
+    std::vector<uint32_t> expectIdx = {3, 3, 2, 2, 1, 1, 3, 3, 2, 2, 1, 1};
+    for (size_t i = 0; i < expectVal.size(); ++i) {
+        EXPECT_EQ(expectVal[i], out0[i]);
+        EXPECT_EQ(expectIdx[i], out1[i]);
+    }
+}
+
+TEST(kernel, TopKCpu2) {
+    // build routine
+    auto inputTensor = Tensor::share(DataType::F32, Shape{2, 4, 2});
+    auto outputTensor0 = Tensor::share(DataType::F32, Shape{1, 4, 2});
+    auto outputTensor1 = Tensor::share(DataType::U32, Shape{1, 4, 2});
+
+    auto kernel = TopKCpu::build(TopKInfo(1, 0, *inputTensor));
+    ASSERT_TRUE(kernel);
+    auto res = runtime::Resources();
+    auto routine = kernel->lower(res).routine;
+    // put input data
+    std::vector<float> ins(inputTensor->elementsSize());
+    std::vector<float> out0(outputTensor0->elementsSize());
+    std::vector<uint32_t> out1(outputTensor1->elementsSize());
+
+    std::iota(ins.begin(), ins.end(), 0);
+    // inference
+    void const *inputs[]{ins.data()};
+    void *outputs[]{out0.data(), out1.data()};
+    routine(res, nullptr, inputs, outputs);
+
+    // check
+    std::vector<float> expectVal = {8, 9, 10, 11, 12, 13, 14, 15};
+    std::vector<uint32_t> expectIdx = {1, 1, 1, 1, 1, 1, 1, 1};
+    for (size_t i = 0; i < expectVal.size(); ++i) {
+        EXPECT_EQ(expectVal[i], out0[i]);
+        EXPECT_EQ(expectIdx[i], out1[i]);
+    }
+}
+
+TEST(kernel, TopKCpu3) {
+    // build routine
+    auto inputTensor = Tensor::share(DataType::F32, Shape{2, 3, 2, 2});
+    auto outputTensor0 = Tensor::share(DataType::F32, Shape{1, 3, 2, 2});
+    auto outputTensor1 = Tensor::share(DataType::U32, Shape{1, 3, 2, 2});
+
+    auto kernel = TopKCpu::build(TopKInfo(1, 0, *inputTensor));
+    ASSERT_TRUE(kernel);
+    auto res = runtime::Resources();
+    auto routine = kernel->lower(res).routine;
+    // put input data
+    std::vector<float> ins(inputTensor->elementsSize());
+    std::vector<float> out0(outputTensor0->elementsSize());
+    std::vector<uint32_t> out1(outputTensor1->elementsSize());
+
+    std::iota(ins.begin(), ins.end(), 0);
+    // inference
+    void const *inputs[]{ins.data()};
+    void *outputs[]{out0.data(), out1.data()};
+    routine(res, nullptr, inputs, outputs);
+
+    // check
+    std::vector<float> expectVal = {12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23};
+    std::vector<uint32_t> expectIdx = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
+    for (size_t i = 0; i < expectVal.size(); ++i) {
+        EXPECT_EQ(expectVal[i], out0[i]);
+        EXPECT_EQ(expectIdx[i], out1[i]);
+    }
+}
+
+TEST(kernel, TopKCpu4) {
+    // build routine
+    auto inputTensor = Tensor::share(DataType::F32, Shape{2, 3, 2, 2});
+    auto outputTensor0 = Tensor::share(DataType::F32, Shape{2, 2, 2, 2});
+    auto outputTensor1 = Tensor::share(DataType::U32, Shape{2, 2, 2, 2});
+
+    auto kernel = TopKCpu::build(TopKInfo(2, 1, *inputTensor));
+    ASSERT_TRUE(kernel);
+    auto res = runtime::Resources();
+    auto routine = kernel->lower(res).routine;
+    // put input data
+    std::vector<float> ins(inputTensor->elementsSize());
+    std::vector<float> out0(outputTensor0->elementsSize());
+    std::vector<uint32_t> out1(outputTensor1->elementsSize());
+
+    std::iota(ins.begin(), ins.end(), 0);
+    // inference
+    void const *inputs[]{ins.data()};
+    void *outputs[]{out0.data(), out1.data()};
+    routine(res, nullptr, inputs, outputs);
+
+    // check
+    std::vector<float> expectVal = {8, 9, 10, 11, 4, 5, 6, 7, 20, 21, 22, 23, 16, 17, 18, 19};
+    std::vector<uint32_t> expectIdx = {2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1};
+    for (size_t i = 0; i < expectVal.size(); ++i) {
+        EXPECT_EQ(expectVal[i], out0[i]);
+        EXPECT_EQ(expectIdx[i], out1[i]);
+    }
+}
+
+TEST(kernel, TopKCpu5) {
+    // build routine
+    auto inputTensor = Tensor::share(DataType::F32, Shape{2, 3, 2, 2});
+    auto outputTensor0 = Tensor::share(DataType::F32, Shape{2, 3, 1, 2});
+    auto outputTensor1 = Tensor::share(DataType::U32, Shape{2, 3, 1, 2});
+
+    auto kernel = TopKCpu::build(TopKInfo(1, 2, *inputTensor));
+    ASSERT_TRUE(kernel);
+    auto res = runtime::Resources();
+    auto routine = kernel->lower(res).routine;
+    // put input data
+    std::vector<float> ins(inputTensor->elementsSize());
+    std::vector<float> out0(outputTensor0->elementsSize());
+    std::vector<uint32_t> out1(outputTensor1->elementsSize());
+
+    std::iota(ins.begin(), ins.end(), 0);
+    // inference
+    void const *inputs[]{ins.data()};
+    void *outputs[]{out0.data(), out1.data()};
+    routine(res, nullptr, inputs, outputs);
+
+    // check
+    std::vector<float> expectVal = {2, 3, 6, 7, 10, 11, 14, 15, 18, 19, 22, 23};
+    std::vector<uint32_t> expectIdx = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
+    for (size_t i = 0; i < expectVal.size(); ++i) {
+        EXPECT_EQ(expectVal[i], out0[i]);
+        EXPECT_EQ(expectIdx[i], out1[i]);
+    }
+}
+
+TEST(kernel, TopKCpu6) {
+    // build routine
+    auto inputTensor = Tensor::share(DataType::F32, Shape{2, 3, 2, 2});
+    auto outputTensor0 = Tensor::share(DataType::F32, Shape{2, 3, 2, 1});
+    auto outputTensor1 = Tensor::share(DataType::U32, Shape{2, 3, 2, 1});
+
+    auto kernel = TopKCpu::build(TopKInfo(1, 3, *inputTensor));
+    ASSERT_TRUE(kernel);
+    auto res = runtime::Resources();
+    auto routine = kernel->lower(res).routine;
+    // put input data
+    std::vector<float> ins(inputTensor->elementsSize());
+    std::vector<float> out0(outputTensor0->elementsSize());
+    std::vector<uint32_t> out1(outputTensor1->elementsSize());
+
+    std::iota(ins.begin(), ins.end(), 0);
+    // inference
+    void const *inputs[]{ins.data()};
+    void *outputs[]{out0.data(), out1.data()};
+    routine(res, nullptr, inputs, outputs);
+
+    // check
+    std::vector<float> expectVal = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23};
+    std::vector<uint32_t> expectIdx = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
+    for (size_t i = 0; i < expectVal.size(); ++i) {
+        EXPECT_EQ(expectVal[i], out0[i]);
+        EXPECT_EQ(expectIdx[i], out1[i]);
+    }
+}
\ No newline at end of file
diff --git a/src/04kernel/test/kernels/topk/test_cuda.cpp b/src/04kernel/test/kernels/topk/test_cuda.cpp
new file mode 100644
index 00000000..d5c252bd
--- /dev/null
+++ b/src/04kernel/test/kernels/topk/test_cuda.cpp
@@ -0,0 +1,68 @@
+#ifdef USE_CUDA
+
+#include "../../../src/kernels/topk/cpu_kernel.hh"
+#include "../../../src/kernels/topk/cuda_kernel.hh"
+#include "hardware/device_manager.h"
+#include <gtest/gtest.h>
+#include <numeric>
+
+using namespace refactor;
+using namespace kernel;
+using namespace hardware;
+
+TEST(kernel, TopKCuda) {
+    // build routine
+    auto inputTensor = Tensor::share(DataType::F32, Shape{3, 4});
+    std::vector<Arc<Tensor>> outputTensors{
+        Tensor::share(DataType::F32, Shape{3, 3}),
+        Tensor::share(DataType::U32, Shape{3, 3})};
+
+    auto kCpu = TopKCpu::build(TopKInfo(3, 1, *inputTensor));
+    auto kCuda = TopKCuda::build(TopKInfo(3, 1, *inputTensor));
+    ASSERT_TRUE(kCpu);
+    ASSERT_TRUE(kCuda);
+    auto res = runtime::Resources();
+    auto rCpu = kCpu->lower(res).routine;
+    auto rCuda = kCuda->lower(res).routine;
+
+    // device malloc
+    auto &dev = *device::init(Device::Type::Nvidia, 0, "");
+    Arc<Device::Blob>
+        gpuIn = dev.malloc(inputTensor->bytesSize()),
+        gpuOuts[]{
+            dev.malloc(outputTensors[0]->bytesSize()),
+            dev.malloc(outputTensors[1]->bytesSize()),
+        };
+    // put input data
+    std::vector<float> data(inputTensor->elementsSize());
+    std::vector<float> outCpu1(outputTensors[0]->elementsSize());
+    std::vector<uint32_t> outCpu2(outputTensors[1]->elementsSize());
+    std::vector<float> out1(outputTensors[0]->elementsSize());
+    std::vector<uint32_t> out2(outputTensors[1]->elementsSize());
+
+    std::iota(data.begin(), data.end(), 0);
+    gpuIn->copyFromHost(data.data(), inputTensor->bytesSize());
+    // inference
+    {
+        void const *inputs[]{*gpuIn};
+        void *outputs[]{*gpuOuts[0], *gpuOuts[1]};
+        rCuda(res, nullptr, inputs, outputs);
+    }
+    {
+        void const *inputs[]{data.data()};
+        void *outputs[]{outCpu1.data(), outCpu2.data()};
+        rCpu(res, nullptr, inputs, outputs);
+    }
+    // check
+    gpuOuts[0]->copyToHost(out1.data(), outputTensors[0]->bytesSize());
+    EXPECT_EQ(out1, outCpu1);
+    gpuOuts[1]->copyToHost(out2.data(), outputTensors[1]->bytesSize());
+    EXPECT_EQ(out2, outCpu2);
+}
+
+#endif
diff --git a/src/05computation/include/computation/operators/moe.h b/src/05computation/include/computation/operators/moe.h
new file mode 100644
index 00000000..775c9fd7
--- /dev/null
+++ b/src/05computation/include/computation/operators/moe.h
@@ -0,0 +1,37 @@
+#ifndef COMPUTATION_MOE_H
+#define COMPUTATION_MOE_H
+
+#include "../operator.h"
+
+namespace refactor::computation {
+
+    struct AssignPos final : public Operator {
+        int64_t topk, numExperts;
+
+        constexpr explicit AssignPos(int64_t topk, int64_t numExperts) noexcept
+            : Operator(), topk(topk), numExperts(numExperts) {}
+
+        static size_t typeId() noexcept;
+        size_t opTypeId() const noexcept final;
+        std::string_view name() const noexcept final;
+        kernel::CollectorBox candidateKernels(Target) const final;
+        std::string serialize() const noexcept final;
+    };
+
+    struct Reorder final : public Operator {
+        bool scatter;
+        int64_t topk;
+
+        constexpr explicit Reorder(bool scatter, int64_t topk) noexcept
+            : Operator(), scatter(scatter), topk(topk) {}
+
+        static size_t typeId() noexcept;
+        size_t opTypeId() const noexcept final;
+        std::string_view name() const noexcept final;
+        kernel::CollectorBox candidateKernels(Target) const final;
+        std::string serialize() const noexcept final;
+    };
+
+}// namespace refactor::computation
+
+#endif// COMPUTATION_MOE_H
diff --git a/src/05computation/include/computation/operators/topk.h b/src/05computation/include/computation/operators/topk.h
new file mode 100644
index 00000000..d5c401f4
--- /dev/null
+++ b/src/05computation/include/computation/operators/topk.h
@@ -0,0 +1,20 @@
+#ifndef COMPUTATION_TOPK_H
+#define COMPUTATION_TOPK_H
+
+#include "../operator.h"
+
+namespace refactor::computation {
+
+    struct TopK final : public Operator {
+        uint32_t topk, axis;
+
+        constexpr TopK(uint32_t topk, uint32_t axis) noexcept : topk(topk), axis(axis) {}
+
+        static size_t typeId() noexcept;
+        size_t opTypeId() const noexcept final;
+        std::string_view name() const noexcept final;
+        kernel::CollectorBox candidateKernels(Target) const noexcept final;
+    };
+
+}// namespace refactor::computation
+
+#endif// COMPUTATION_TOPK_H
return "moe::AssignPos"; } + auto AssignPos::candidateKernels(Target target) const -> kernel::CollectorBox { + using Collector_ = kernel::AssignPosCollector; + return std::make_unique(target, topk, numExperts); + } + auto AssignPos::serialize() const noexcept -> std::string { + return "moe::AssignPos()"; + } + + auto Reorder::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + auto Reorder::opTypeId() const noexcept -> size_t { return typeId(); } + auto Reorder::name() const noexcept -> std::string_view { return "moe::Reorder"; } + auto Reorder::candidateKernels(Target target) const -> kernel::CollectorBox { + using Collector_ = kernel::ReorderCollector; + return std::make_unique(target, scatter, topk); + } + auto Reorder::serialize() const noexcept -> std::string { + return "moe::Reorder()"; + } + +}// namespace refactor::computation diff --git a/src/05computation/src/operators/topk.cc b/src/05computation/src/operators/topk.cc new file mode 100644 index 00000000..f25e9a85 --- /dev/null +++ b/src/05computation/src/operators/topk.cc @@ -0,0 +1,17 @@ +#include "computation/operators/topk.h" +#include "kernel/collectors/topk.h" + +namespace refactor::computation { + + size_t TopK::typeId() noexcept { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + size_t TopK::opTypeId() const noexcept { return typeId(); } + std::string_view TopK::name() const noexcept { return "TopK"; } + auto TopK::candidateKernels(Target target) const noexcept -> kernel::CollectorBox { + using Collector_ = kernel::TopKCollector; + return std::make_unique(target, topk, axis); + } + +}// namespace refactor::computation diff --git a/src/07onnx/src/operators.cpp b/src/07onnx/src/operators.cpp index 0981f720..8e50a810 100644 --- a/src/07onnx/src/operators.cpp +++ b/src/07onnx/src/operators.cpp @@ -38,6 +38,7 @@ #include "operators/transpose.hh" #include "operators/unsqueeze.hh" #include "operators/where.hh" +#include "operators/topk.hh" namespace refactor::onnx { @@ -131,6 +132,7 @@ namespace refactor::onnx { REGISTER(Where , Where ); REGISTER(HardSigmoid , HardSigmoid ); REGISTER(Pad , Pad ); + REGISTER(TopK , TopK ); // clang-format on #undef REGISTER } diff --git a/src/07onnx/src/operators/topk.cc b/src/07onnx/src/operators/topk.cc new file mode 100644 index 00000000..c1e908e6 --- /dev/null +++ b/src/07onnx/src/operators/topk.cc @@ -0,0 +1,55 @@ +#include "common.h" +#include "topk.hh" +#include "computation/operators/topk.h" +#include + +namespace refactor::onnx { + using Op = TopK; + + Op::TopK(Int topk, Int axis):Operator(), topk(topk), axis(axis){} + + auto Op::build(ModelContext const &, std::string_view opType, Attributes attributes) -> OpBox { + auto axis = attributes["axis"].int_(); + auto topk = attributes["topk"].int_(); + return OpBox(std::make_unique(topk, axis)); + } + + auto Op::typeId() -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + + auto Op::opTypeId() const -> size_t { return typeId(); } + auto Op::opTypeName() const -> std::string_view { return "TopK"; } + + auto Op::infer(TensorRefs inputs, InferOptions const &options) const -> InferResult { + if (inputs.empty() || inputs.size() >= 2) { + return Err(InferError(ERROR_MSG("Input size error"))); + } + auto const &input = inputs[0]; + auto rank = input.rank(); + auto axis_ = axis < 0 ? 
diff --git a/src/07onnx/src/operators/topk.cc b/src/07onnx/src/operators/topk.cc
new file mode 100644
index 00000000..c1e908e6
--- /dev/null
+++ b/src/07onnx/src/operators/topk.cc
@@ -0,0 +1,55 @@
+#include "common.h"
+#include "topk.hh"
+#include "computation/operators/topk.h"
+
+namespace refactor::onnx {
+    using Op = TopK;
+
+    Op::TopK(Int topk, Int axis) : Operator(), topk(topk), axis(axis) {}
+
+    auto Op::build(ModelContext const &, std::string_view opType, Attributes attributes) -> OpBox {
+        auto axis = attributes["axis"].int_();
+        auto topk = attributes["topk"].int_();
+        return OpBox(std::make_unique<Op>(topk, axis));
+    }
+
+    auto Op::typeId() -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast<size_t>(&ID);
+    }
+
+    auto Op::opTypeId() const -> size_t { return typeId(); }
+    auto Op::opTypeName() const -> std::string_view { return "TopK"; }
+
+    auto Op::infer(TensorRefs inputs, InferOptions const &options) const -> InferResult {
+        if (inputs.size() != 1) {
+            return Err(InferError(ERROR_MSG("Input size error")));
+        }
+        auto const &input = inputs[0];
+        auto rank = input.rank();
+        auto axis_ = axis < 0 ? axis + rank : axis;
+        if (rank <= axis_) {
+            return Err(InferError(ERROR_MSG("axis error")));
+        }
+        if (topk < 0 || topk > input.shape[axis_].value()) {
+            return Err(InferError(ERROR_MSG("topk error")));
+        }
+
+        Tensors ans(2, nullptr);
+        auto dependencies = extractDependency(inputs);
+        ans[0] = Tensor::share(input.dataType, input.shape, dependencies);
+        ans[0]->shape[axis_] = DimExpr(topk);
+        ans[1] = Tensor::share(DataType::U32, input.shape, dependencies);
+        ans[1]->shape[axis_] = DimExpr(topk);
+        return Ok(std::move(ans));
+    }
+
+    auto Op::lower(TensorRefs inputs) const -> computation::OpBox {
+        using Op_ = computation::TopK;
+        auto rank = inputs[0].rank();
+        auto axis_ = axis < 0 ? axis + rank : axis;
+        return std::make_unique<Op_>(topk, axis_);
+    }
+
+}// namespace refactor::onnx
diff --git a/src/07onnx/src/operators/topk.hh b/src/07onnx/src/operators/topk.hh
new file mode 100644
index 00000000..2b86f5bb
--- /dev/null
+++ b/src/07onnx/src/operators/topk.hh
@@ -0,0 +1,23 @@
+#ifndef ONNX_TOPK_HH
+#define ONNX_TOPK_HH
+
+#include "frontend/operator.h"
+
+namespace refactor::onnx {
+    using namespace frontend;
+
+    struct TopK final : public Operator {
+        Int topk, axis;
+        TopK(Int topk, Int axis);
+
+        static size_t typeId();
+        static OpBox build(ModelContext const &, std::string_view, Attributes);
+        size_t opTypeId() const final;
+        std::string_view opTypeName() const final;
+        InferResult infer(TensorRefs, InferOptions const &) const final;
+        computation::OpBox lower(TensorRefs) const final;
+    };
+
+}// namespace refactor::onnx
+
+#endif// ONNX_TOPK_HH
diff --git a/src/08-02moe/CMakeLists.txt b/src/08-02moe/CMakeLists.txt
new file mode 100644
index 00000000..25b882cc
--- /dev/null
+++ b/src/08-02moe/CMakeLists.txt
@@ -0,0 +1,15 @@
+cmake_minimum_required(VERSION 3.12 FATAL_ERROR)
+project(moe VERSION 0.0.0 LANGUAGES CXX)
+message(STATUS "Project " ${PROJECT_NAME} " version " ${PROJECT_VERSION})
+
+file(GLOB_RECURSE MOE_SRC src/*.cc src/*.cpp)
+add_library(moe STATIC ${MOE_SRC})
+target_link_libraries(moe PUBLIC frontend)
+target_include_directories(moe PUBLIC include)
+
+file(GLOB_RECURSE MOE_TEST test/*.cpp)
+if(MOE_TEST)
+    add_executable(moe_test ${MOE_TEST})
+    add_test(moe_test moe_test)
+    target_link_libraries(moe_test moe GTest::gtest_main Backward::Object)
+endif()
diff --git a/src/08-02moe/include/moe/operators.h b/src/08-02moe/include/moe/operators.h
new file mode 100644
index 00000000..b4221025
--- /dev/null
+++ b/src/08-02moe/include/moe/operators.h
@@ -0,0 +1,10 @@
+#ifndef MOE_OPERATORS_H
+#define MOE_OPERATORS_H
+
+namespace refactor::moe {
+
+    void register_();
+
+}// namespace refactor::moe
+
+#endif// MOE_OPERATORS_H
diff --git a/src/08-02moe/src/operators.cpp b/src/08-02moe/src/operators.cpp
new file mode 100644
index 00000000..5db39632
--- /dev/null
+++ b/src/08-02moe/src/operators.cpp
@@ -0,0 +1,16 @@
+#include "moe/operators.h"
+#include "operators/moe.hh"
+
+namespace refactor::moe {
+    using namespace frontend;
+
+    void register_() {
+#define REGISTER(NAME, CLASS) Operator::register_<CLASS>("moe::" #NAME)
+        // clang-format off
+        REGISTER(AssignPos , AssignPos );
+        REGISTER(Reorder   , Reorder   );
+        // clang-format on
+#undef REGISTER
+    }
+
+}// namespace refactor::moe
diff --git a/src/08-02moe/src/operators/moe.cc b/src/08-02moe/src/operators/moe.cc
new file mode 100644
index 00000000..68bfeff5
--- /dev/null
+++ b/src/08-02moe/src/operators/moe.cc
@@ -0,0 +1,83 @@
+#include "moe.hh"
+#include "common.h"
+#include "computation/operators/moe.h"
+
+namespace refactor::moe {
+
+    AssignPos::AssignPos(Int topk, Int numExperts) : Operator(), topk(topk), numExperts(numExperts) {}
+
+    auto AssignPos::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox {
+        auto topk = attributes["topk"].int_();
+        auto num_experts = attributes["num_experts"].int_();
+        return OpBox(std::make_unique<AssignPos>(topk, num_experts));
+    }
+    auto AssignPos::typeId() -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast<size_t>(&ID);
+    }
+
+    auto AssignPos::opTypeId() const -> size_t { return typeId(); }
+    auto AssignPos::opTypeName() const -> std::string_view { return "moe::AssignPos"; }
+
+    auto AssignPos::infer(TensorRefs inputs, InferOptions const &) const -> InferResult {
+        EXPECT_SIZE(1)
+
+        auto const &gate = inputs[0];
+        if (topk < 0 || numExperts < 0 || topk > numExperts) {
+            return Err(InferError(ERROR_MSG("topk or numExperts is invalid")));
+        }
+        if (gate.dataType != DataType::I64) {
+            return Err(InferError(ERROR_MSG("Input data type not support")));
+        }
+
+        return Ok(Tensors{Tensor::share(gate.dataType, Shape{DimExpr(numExperts)}, extractDependency(inputs)),
+                          Tensor::share(gate.dataType, gate.shape, extractDependency(inputs))});
+    }
+
+    auto AssignPos::lower(TensorRefs) const -> computation::OpBox {
+        using Op_ = computation::AssignPos;
+        return std::make_unique<Op_>(topk, numExperts);
+    }
+
+    Reorder::Reorder(bool scatter, Int topk, Int dim) : Operator(), scatter(scatter), top(topk), dim(dim) {}
+
+    auto Reorder::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox {
+        auto topk = attributes["topk"].int_();
+        bool scatter = attributes["scatter"].int_() != 0;
+        auto dim = attributes["dim"].int_();
+        return OpBox(std::make_unique<Reorder>(scatter, topk, dim));
+    }
+    auto Reorder::typeId() -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast<size_t>(&ID);
+    }
+
+    auto Reorder::opTypeId() const -> size_t { return typeId(); }
+    auto Reorder::opTypeName() const -> std::string_view { return "moe::Reorder"; }
+
+    auto Reorder::infer(TensorRefs inputs, InferOptions const &) const -> InferResult {
+        EXPECT_SIZE(2)
+        auto const &input = inputs[0];
+        auto const &pos = inputs[1];
+        if (dim != 0) {
+            return Err(InferError(ERROR_MSG("only dim = 0 is supported")));
+        }
+        if (top < 0) {
+            return Err(InferError(ERROR_MSG("topk is invalid")));
+        }
+        // With scatter, pos holds one entry per (token, choice) pair; without it, one per token row.
+        auto rows = input.elementsSize() / input.shape[input.shape.size() - 1].value();
+        if (scatter && rows * top != pos.elementsSize()) {
+            return Err(InferError(ERROR_MSG("input and pos sizes do not match")));
+        } else if (!scatter && rows != pos.elementsSize()) {
+            return Err(InferError(ERROR_MSG("input and pos sizes do not match")));
+        }
+        if (pos.dataType != DataType::I64) {
+            return Err(InferError(ERROR_MSG("Input data type not support")));
+        }
+
+        return Ok(Tensors{Tensor::share(input.dataType, pos.shape, extractDependency(inputs))});
+    }
+
+    auto Reorder::lower(TensorRefs) const -> computation::OpBox {
+        using Op_ = computation::Reorder;
+        return std::make_unique<Op_>(scatter, top);
+    }
+
+}// namespace refactor::moe
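The size check in `Reorder::infer` is pure row arithmetic: with `scatter`, `pos` needs one entry per (token, choice) pair, otherwise one per token row. A quick sketch with the shapes from the kernel test above (`input {8, 2}`, `top = 2`):

```cpp
#include <cassert>
#include <cstddef>

int main() {
    size_t elements = 8 * 2, lastDim = 2, top = 2;
    size_t rows = elements / lastDim;// 8 token rows
    assert(rows * top == 16);        // pos size required when scattering
    assert(rows == 8);               // pos size required when gathering
}
```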
diff --git a/src/08-02moe/src/operators/moe.hh b/src/08-02moe/src/operators/moe.hh
new file mode 100644
index 00000000..42501bdb
--- /dev/null
+++ b/src/08-02moe/src/operators/moe.hh
@@ -0,0 +1,38 @@
+#ifndef MOE_HH
+#define MOE_HH
+
+#include "frontend/operator.h"
+
+namespace refactor::moe {
+    using namespace frontend;
+
+    struct AssignPos final : public Operator {
+        Int topk, numExperts;
+        explicit AssignPos(Int topk, Int numExperts);
+
+        static OpBox build(ModelContext const &, std::string_view, Attributes);
+        static size_t typeId();
+
+        size_t opTypeId() const final;
+        std::string_view opTypeName() const final;
+        InferResult infer(TensorRefs, InferOptions const &) const final;
+        computation::OpBox lower(TensorRefs) const final;
+    };
+
+    struct Reorder final : public Operator {
+        bool scatter;
+        Int top, dim;
+        explicit Reorder(bool scatter, Int topk, Int dim);
+
+        static OpBox build(ModelContext const &, std::string_view, Attributes);
+        static size_t typeId();
+
+        size_t opTypeId() const final;
+        std::string_view opTypeName() const final;
+        InferResult infer(TensorRefs, InferOptions const &) const final;
+        computation::OpBox lower(TensorRefs) const final;
+    };
+
+}// namespace refactor::moe
+
+#endif// MOE_HH
diff --git a/src/08-02moe/test/test_moe.cpp b/src/08-02moe/test/test_moe.cpp
new file mode 100644
index 00000000..c735d803
--- /dev/null
+++ b/src/08-02moe/test/test_moe.cpp
@@ -0,0 +1,25 @@
+#include "../src/operators/moe.hh"
+#include "moe/operators.h"
+#include <gtest/gtest.h>
+
+using namespace refactor;
+using namespace moe;
+
+TEST(infer, AssignPos) {
+    moe::register_();
+    auto edges = Edges{
+        {Tensor::share(DataType::I64, Shape{DimExpr(8), DimExpr(2)}, {}), ""},// gate 8x2
+    };
+    count_t inputs[]{0};
+    auto infered = AssignPos(2, 4).infer(TensorRefs(edges, inputs), {true});
+    ASSERT_TRUE(infered.isOk());
+    auto outputs = std::move(infered.unwrap());
+    ASSERT_EQ(outputs.size(), 2);
+    auto expert_cnt = std::move(outputs[0]);
+    ASSERT_EQ(expert_cnt->dataType, DataType::I64);// infer copies the gate's dtype
+    ASSERT_EQ(expert_cnt->shape, (Shape{DimExpr(4)}));
+    auto pos = std::move(outputs[1]);
+    ASSERT_EQ(pos->dataType, DataType::I64);
+    ASSERT_EQ(pos->shape, (Shape{DimExpr(16)}));
+}
diff --git a/src/09python_ffi/CMakeLists.txt b/src/09python_ffi/CMakeLists.txt
index ccce34d3..50fc535a 100644
--- a/src/09python_ffi/CMakeLists.txt
+++ b/src/09python_ffi/CMakeLists.txt
@@ -7,7 +7,7 @@ add_subdirectory(pybind11)
 
 file(GLOB_RECURSE PYFFI_SRC src/*.cc src/*.cpp)
 pybind11_add_module(python_ffi SHARED ${PYFFI_SRC})
-target_link_libraries(python_ffi PRIVATE onnx llm communication)
+target_link_libraries(python_ffi PRIVATE onnx llm communication moe)
 target_include_directories(python_ffi PRIVATE include)
 
 # EXAMPLE_VERSION_INFO is defined by setup.py and passed into the C++ code as a
diff --git a/src/09python_ffi/src/main.cpp b/src/09python_ffi/src/main.cpp
index 48a4ea6f..d2d2c8ce 100644
--- a/src/09python_ffi/src/main.cpp
+++ b/src/09python_ffi/src/main.cpp
@@ -3,6 +3,7 @@
 #include "import.h"
 #include "llm/operators.h"
 #include "onnx/operators.h"
+#include "moe/operators.h"
 #include <pybind11/stl.h>// keep this line to convert stl types
 
 namespace py = pybind11;
@@ -17,6 +18,7 @@
         onnx::register_();
         llm::register_();
         communication::register_();
+        moe::register_();
         // clang-format off