add npu extension and focal loss adapter
ckirchhoff2021 committed Sep 19, 2022
commit 716b3b3693e97217bf908c34e7ca1544963d9699
42 changes: 42 additions & 0 deletions mmcv/ops/csrc/common/pytorch_npu_helper.hpp
@@ -0,0 +1,42 @@
/******************************************************************************
* Copyright (c) 2022 Huawei Technologies Co., Ltd
* All rights reserved.
*
* Licensed under the BSD 3-Clause License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://opensource.org/licenses/BSD-3-Clause
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
******************************************************************************/

#ifndef PYTORCH_NPU_HELPER_HPP_
#define PYTORCH_NPU_HELPER_HPP_

#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"

#ifdef MMCV_WITH_NPU
#include <torch_npu/csrc/framework/utils/CalcuOpUtil.h>
#include <torch_npu/csrc/framework/utils/OpAdapter.h>
#include <torch_npu/csrc/aten/NPUNativeFunctions.h>
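// torch_npu 1.8 dispatches NPU tensors through the XLA device key,
// hence the XLA registration and at::kXLA check below.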
#define NPU_NAME_SPACE at_npu::native
#define REGISTER_NPU_IMPL(key, value) REGISTER_DEVICE_IMPL(key, XLA, value)
#define CHECK_NPU(x) \
  TORCH_CHECK(x.device().type() == at::kXLA, #x " must be an NPU tensor")
#else
// for the torch 1.5.0 adapter only
#include <torch/csrc/ATen/native/npu/utils/OpAdapter.h>
#include <torch/csrc/ATen/native/npu/utils/CalcuOpUtil.h>
#define NPU_NAME_SPACE at::native::npu
#define REGISTER_NPU_IMPL(key, value) REGISTER_DEVICE_IMPL(key, NPU, value)
#define CHECK_NPU(x) \
  TORCH_CHECK(x.device().type() == at::kNPU, #x " must be an NPU tensor")
#endif

#endif // PYTORCH_NPU_HELPER_HPP_
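
For orientation, a minimal sketch of how an op file is meant to consume this header. The names my_op_impl, my_op_npu, and the "MyOp" kernel are hypothetical, not part of this PR; OpCommand comes from the torch_npu headers included above:

// Hypothetical example: registering an NPU kernel via the new helper.
#include "pytorch_npu_helper.hpp"

void my_op_impl(Tensor input, Tensor output);  // generic, device-dispatched entry

void my_op_npu(Tensor input, Tensor output) {
  CHECK_NPU(input);  // rejects tensors that are not on an NPU device
  OpCommand cmd;
  cmd.Name("MyOp").Input(input).Output(output).Run();
}

// Binds my_op_npu as the NPU implementation of my_op_impl.
REGISTER_NPU_IMPL(my_op_impl, my_op_npu);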
128 changes: 128 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp
@@ -0,0 +1,128 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;


void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
                                    Tensor output, float gamma, float alpha) {
  at::Tensor target_y = at::reshape(target, input.sizes());
  target_y =
      at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt);
  at::Tensor grad_up = at::ones_like(input);
  int64_t weight_size = weight.size(0);
  at::Tensor weight_y = at::ones_like(input);
  if (weight_size > 0) {
    weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight,
                                                                 input.sizes());
  }

  OpCommand cmd;
  cmd.Name("SigmoidFocalLoss")
      .Input(input)
      .Input(target_y)
      .Input(grad_up)
      .Input(weight_y)
      .Output(output)  // forward writes the loss, not grad_input
      .Attr("gamma", gamma)
      .Attr("alpha", alpha)
      .Attr("reduction", "none")
      .Run();
}

void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
                                     Tensor output, float gamma, float alpha);

void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
                                     Tensor grad_input, float gamma,
                                     float alpha) {
  at::Tensor target_y = at::reshape(target, input.sizes());
  target_y =
      at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt);
  at::Tensor grad_up = at::ones_like(input);
  int64_t weight_size = weight.size(0);
  at::Tensor weight_y = at::ones_like(input);
  if (weight_size > 0) {
    weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight,
                                                                 input.sizes());
  }

  OpCommand cmd;
  cmd.Name("SigmoidFocalLossGrad")
      .Input(input)
      .Input(target_y)
      .Input(grad_up)
      .Input(weight_y)
      .Output(grad_input)
      .Attr("gamma", gamma)
      .Attr("alpha", alpha)
      .Attr("reduction", "none")
      .Run();
}

void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target, Tensor weight,
                                      Tensor grad_input, float gamma,
                                      float alpha);

void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
                                    Tensor output, float gamma, float alpha) {
  int64_t n_class = input.size(1);
  at::Tensor target_y =
      at_npu::native::NPUNativeFunctions::one_hot(target, n_class);
  target_y =
      at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt);
  at::Tensor grad_up = at::ones_like(input);
  int64_t weight_size = weight.size(0);
  at::Tensor weight_y = at::ones_like(input);
  if (weight_size > 0) {
    weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight,
                                                                 input.sizes());
  }

  OpCommand cmd;
  cmd.Name("SoftmaxFocalLoss")
      .Input(input)
      .Input(target_y)
      .Input(grad_up)
      .Input(weight_y)
      .Output(output)  // forward writes the loss, not grad_input
      .Attr("gamma", gamma)
      .Attr("alpha", alpha)
      .Attr("reduction", "none")
      .Run();
}

void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
                                     Tensor output, float gamma, float alpha);

void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
                                     Tensor buff, Tensor grad_input, float gamma,
                                     float alpha) {
  int64_t n_class = input.size(1);
  at::Tensor target_y =
      at_npu::native::NPUNativeFunctions::one_hot(target, n_class);
  target_y =
      at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt);
  at::Tensor grad_up = at::ones_like(input);
  int64_t weight_size = weight.size(0);
  at::Tensor weight_y = at::ones_like(input);
  if (weight_size > 0) {
    weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight,
                                                                 input.sizes());
  }

  OpCommand cmd;
  cmd.Name("SoftmaxFocalLossGrad")
      .Input(input)
      .Input(target_y)
      .Input(grad_up)
      .Input(weight_y)
      .Output(grad_input)
      .Attr("gamma", gamma)
      .Attr("alpha", alpha)
      .Attr("reduction", "none")
      .Run();
}

void softmax_focal_loss_backward_impl(Tensor input, Tensor target, Tensor weight,
                                      Tensor buff, Tensor grad_input, float gamma,
                                      float alpha);

REGISTER_NPU_IMPL(sigmoid_focal_loss_forward_impl, sigmoid_focal_loss_forward_npu);
REGISTER_NPU_IMPL(sigmoid_focal_loss_backward_impl, sigmoid_focal_loss_backward_npu);
REGISTER_NPU_IMPL(softmax_focal_loss_forward_impl, softmax_focal_loss_forward_npu);
REGISTER_NPU_IMPL(softmax_focal_loss_backward_impl, softmax_focal_loss_backward_npu);
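
For context, the *_impl symbols registered above are the generic entry points declared in mmcv/ops/csrc/pytorch/focal_loss.cpp; they route to whichever per-device implementation the registry holds, roughly as in this sketch:

// Sketch (simplified from mmcv's focal_loss.cpp): the generic entry point
// dispatches on the tensors' device, reaching the NPU kernel registered above.
void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
                                     Tensor output, float gamma, float alpha) {
  DISPATCH_DEVICE_IMPL(sigmoid_focal_loss_forward_impl, input, target, weight,
                       output, gamma, alpha);
}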
3 changes: 2 additions & 1 deletion mmcv/ops/focal_loss.py
@@ -143,7 +143,8 @@ def forward(ctx,
                weight: Optional[torch.Tensor] = None,
                reduction='mean') -> torch.Tensor:

        assert isinstance(target, (torch.LongTensor, torch.cuda.LongTensor))
        assert isinstance(
            target, (torch.Tensor, torch.LongTensor, torch.cuda.LongTensor))
        assert input.dim() == 2
        assert target.dim() == 1
        assert input.size(0) == target.size(0)
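Relaxing the isinstance check to accept any torch.Tensor lets integer target tensors on non-CUDA backends such as the NPU pass this assertion; previously only CPU and CUDA LongTensors were accepted.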
29 changes: 28 additions & 1 deletion setup.py
@@ -280,7 +280,7 @@ def get_extensions():
    if is_rocm_pytorch or torch.cuda.is_available() or os.getenv(
            'FORCE_CUDA', '0') == '1':
        if is_rocm_pytorch:
            define_macros += [('HIP_DIFF', None)]
            define_macros += [('MMCV_WITH_HIP', None)]
Member: Why do we remove HIP_DIFF?
        define_macros += [('MMCV_WITH_CUDA', None)]
        cuda_args = os.getenv('MMCV_CUDA_ARGS')
        extra_compile_args['nvcc'] = [cuda_args] if cuda_args else []
@@ -289,6 +289,7 @@ def get_extensions():
            glob.glob('./mmcv/ops/csrc/pytorch/cuda/*.cu') + \
            glob.glob('./mmcv/ops/csrc/pytorch/cuda/*.cpp')
        extension = CUDAExtension
        include_dirs.append(os.path.abspath('./mmcv/ops/csrc/pytorch'))
        include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
        include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/cuda'))
    elif (hasattr(torch, 'is_mlu_available') and
@@ -329,6 +330,32 @@ def get_extensions():
        extension = CppExtension
        include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
        include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/mps'))
    elif os.getenv('FORCE_NPU', '0') == '1':
        print(f'Compiling {ext_name} only with CPU and NPU')
        try:
            # torch 1.5 builds of torch_npu patch torch.npu onto torch directly
            has_npu = torch.npu.is_available()
            print('torch_npu version 1.5 is available:', has_npu)
            extension = CppExtension
        except Exception:
            try:
                import torch_npu
                from torch_npu.utils.cpp_extension import NpuExtension
                has_npu = torch_npu.npu.is_available()
                print('torch_npu version 1.8 is available:', has_npu)
                define_macros += [('MMCV_WITH_NPU', None)]
                extension = NpuExtension
            except Exception:
                print('cannot find any torch_npu')
                return extensions

        # src
        op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \
            glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') + \
            glob.glob('./mmcv/ops/csrc/common/npu/*.cpp') + \
            glob.glob('./mmcv/ops/csrc/pytorch/npu/*.cpp')

        include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
        include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/npu'))
    else:
        print(f'Compiling {ext_name} only with CPU')
        op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \
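
With the new branch in place, the NPU build is opted into explicitly, e.g. FORCE_NPU=1 pip install -e . in an Ascend environment with torch_npu installed (an assumed invocation; any build front end that runs setup.py works). Without the flag, the existing CPU/CUDA/MLU/MPS paths are untouched.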
Expand Down