Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
compile error
  • Loading branch information
xgqdut2016 committed Mar 5, 2024
commit 6981bafb8e316aeced7d9a7ecbecf8d24168ad33
3 changes: 3 additions & 0 deletions src/04kernel/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ if(USE_CUDA)
file(GLOB_RECURSE KERNEL_CUDA_SRC src/*.cu)
add_subdirectory(cuda)
endif()
if(USE_BANG)
file(GLOB_RECURSE KERNEL_BANG_SRC src/*.mlu)
endif()

add_library(kernel STATIC ${KERNEL_SRC} ${KERNEL_CUDA_SRC})
target_link_libraries(kernel PUBLIC runtime)
Expand Down
4 changes: 4 additions & 0 deletions src/04kernel/src/collectors/softmax.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "kernel/collectors/softmax.h"
#include "../kernels/softmax/bang_kernel.hh"
#include "../kernels/softmax/cnnl_kernel.hh"
#include "../kernels/softmax/cpu_kernel.hh"
#include "../kernels/softmax/cuda_kernel.hh"
Expand Down Expand Up @@ -33,6 +34,9 @@ namespace refactor::kernel {
if (auto ptr = SoftmaxCnnl::build(cnnl::SoftmaxAlgo::ACCURATE, info); ptr) {
ans.emplace_back(std::move(ptr));
}
if (auto ptr = SoftmaxBang::build(info); ptr) {
ans.emplace_back(std::move(ptr));
}
break;
}
default:
Expand Down
29 changes: 29 additions & 0 deletions src/04kernel/src/kernels/softmax/bang_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#include "bang_kernel.hh"

namespace refactor::kernel {
using K = SoftmaxBang;

K::SoftmaxBang(SoftmaxInfo info_) noexcept
: Kernel(), info(std::move(info_)) {}

auto K::build(SoftmaxInfo info) noexcept -> KernelBox {
#ifndef USE_BANG
return nullptr;
#endif

return info.type.isFloat()
? std::make_unique<K>(std::move(info))
: nullptr;
}

auto K::typeId() noexcept -> size_t {
static uint8_t ID = 1;
return reinterpret_cast<size_t>(&ID);
}

auto K::kernelTypeId() const noexcept -> size_t { return typeId(); }
auto K::description() const noexcept -> std::string_view {
return "Performing Softmax using BANG";
}

}// namespace refactor::kernel
26 changes: 26 additions & 0 deletions src/04kernel/src/kernels/softmax/bang_kernel.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#ifndef KERNEL_SOFTMAX_BANG_HH
#define KERNEL_SOFTMAX_BANG_HH

#include "cnnl.h"
#include "cnrt.h"
#include "kernel/attributes/softmax_info.h"
#include "kernel/collectors/softmax.h"
namespace refactor::kernel {

struct SoftmaxBang final : public Kernel {
SoftmaxInfo info;

SoftmaxBang(SoftmaxInfo) noexcept;
static KernelBox build(SoftmaxInfo) noexcept;
static size_t typeId() noexcept;

size_t kernelTypeId() const noexcept final;
std::string_view description() const noexcept final;
#ifdef USE_BANG
RoutineWorkspace lower(Resources &) const noexcept final;
#endif
};

}// namespace refactor::kernel

#endif//KERNEL_SOFTMAX_BANG_HH
Loading