6 changes: 4 additions & 2 deletions mmcv/device/__init__.py
@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import ipu, mlu, mps
from . import ipu, mlu, mps, npu
from .scatter_gather import scatter, scatter_kwargs
from .utils import get_device

__all__ = ['mlu', 'ipu', 'mps', 'get_device', 'scatter', 'scatter_kwargs']
__all__ = [
    'npu', 'mlu', 'ipu', 'mps', 'get_device', 'scatter', 'scatter_kwargs'
]
6 changes: 6 additions & 0 deletions mmcv/device/npu/__init__.py
@@ -0,0 +1,6 @@
# Copyright Huawei Technologies Co., Ltd. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from .data_parallel import NPUDataParallel
from .distributed import NPUDistributedDataParallel

__all__ = ['NPUDataParallel', 'NPUDistributedDataParallel']
59 changes: 59 additions & 0 deletions mmcv/device/npu/data_parallel.py
@@ -0,0 +1,59 @@
# Copyright Huawei Technologies Co., Ltd. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.

import sys

import torch

from mmcv.device.scatter_gather import scatter_kwargs
from mmcv.parallel import MMDataParallel


def _check_balance(*args, **kwargs):
    return


# The NPU has no hardware unit comparable to the GPU's multi_processor,
# so the corresponding device properties entry does not exist and cannot
# be checked. We therefore mask the _check_balance function in
# DataParallel so that initialization passes.
for m in sys.modules:
    if m.startswith('torch') or 'mmcv' in m:
        if hasattr(sys.modules[m], '_check_balance'):
            setattr(sys.modules[m], '_check_balance', _check_balance)


class NPUDataParallel(MMDataParallel):
    """The NPUDataParallel module that supports DataContainer.

    NPUDataParallel is a class inherited from MMDataParallel, which
    supports NPU training and inference only.

    The main differences with MMDataParallel:

    - It supports only a single NPU card and uses only the first card
      for training and inference.

    - It uses a direct host-to-device copy instead of stream-background
      scatter.

    .. warning::
        NPUDataParallel supports only single-NPU training. If you need to
        train with multiple NPUs, please use NPUDistributedDataParallel
        instead. If you have multiple NPUs, you can set the device_ids
        parameter to specify which device to run on.

    Args:
        module (:class:`nn.Module`): Module to be encapsulated.
        dim (int): Dimension used to scatter the data. Defaults to 0.
    """

    def __init__(self, *args, dim=0, **kwargs):
        super().__init__(*args, dim=dim, **kwargs)
        device_id = kwargs.get('device_ids', [0])[0]
        self.device_ids = [device_id]
        self.src_device_obj = torch.device(f'npu:{device_id}')
        torch.npu.set_device(self.src_device_obj)

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
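
For readers, a minimal usage sketch of the class above (the module and shapes are hypothetical, and torch_npu is assumed to be installed so that `.npu()` and `torch.npu` are available):

```python
# A minimal sketch, not part of this PR: wrap a module for single-NPU runs.
import torch
from mmcv.device.npu import NPUDataParallel

model = torch.nn.Linear(8, 2)          # placeholder module
model = NPUDataParallel(model.npu(), device_ids=[0])
out = model(torch.randn(4, 8).npu())   # inputs are copied host-to-device
```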
26 changes: 26 additions & 0 deletions mmcv/device/npu/distributed.py
@@ -0,0 +1,26 @@
# Copyright Huawei Technologies Co., Ltd. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.

from mmcv.device.scatter_gather import scatter_kwargs
from mmcv.parallel import MMDistributedDataParallel


class NPUDistributedDataParallel(MMDistributedDataParallel):
    """The DDP module that supports DataContainer.

    NPUDDP differs from MMDDP in one respect: it moves data to the NPU
    by copying instead of scattering.
    """

    def to_kwargs(self, inputs, kwargs, device_id):
        # PyTorch >= 1.8 calls `to_kwargs` instead of `scatter` to move
        # all tensors to device_id.
        return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim)

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def forward(self, *inputs, **kwargs):
        if self.device_ids:
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        return super().forward(*inputs[0], **kwargs[0])
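
A hedged sketch of multi-NPU usage (it assumes the `hccl` process group is already initialized, e.g. via the `dist_utils.py` change below, and that `LOCAL_RANK` is set by the launcher):

```python
# A minimal sketch, not part of this PR: one process per NPU.
import os
import torch
from mmcv.device.npu import NPUDistributedDataParallel

local_rank = int(os.environ['LOCAL_RANK'])
torch.npu.set_device(local_rank)
model = torch.nn.Linear(8, 2).npu()    # placeholder module
model = NPUDistributedDataParallel(model, device_ids=[local_rank])
```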
12 changes: 10 additions & 2 deletions mmcv/device/utils.py
@@ -1,14 +1,22 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE
from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE,
                        IS_NPU_AVAILABLE)


def get_device() -> str:
    """Return the type of the currently available device.

    .. note::
        Since the NPU toolchain provides tools that automatically convert
        CUDA calls, the NPU check must come first; otherwise code running
        on an NPU would enter the CUDA branch.

    Returns:
        str: npu | cuda | mlu | mps | cpu.
    """
    if IS_CUDA_AVAILABLE:
    if IS_NPU_AVAILABLE:
        return 'npu'
    elif IS_CUDA_AVAILABLE:
        return 'cuda'
    elif IS_MLU_AVAILABLE:
        return 'mlu'
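
With the NPU branch ordered first, device-agnostic code keeps working unchanged; a small sketch (assuming torch_npu is imported so the 'npu' device string is registered):

```python
# A minimal sketch, not part of this PR.
import torch
from mmcv.device import get_device

device = get_device()                 # 'npu' | 'cuda' | 'mlu' | 'mps' | 'cpu'
x = torch.zeros(2, 3, device=device)  # allocates on the detected device
```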
35 changes: 35 additions & 0 deletions mmcv/ops/csrc/common/pytorch_npu_helper.hpp
@@ -0,0 +1,35 @@
/******************************************************************************
 * Copyright (c) 2022 Huawei Technologies Co., Ltd
 * All rights reserved.
 *
 * Licensed under the BSD 3-Clause License (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * https://opensource.org/licenses/BSD-3-Clause
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 ******************************************************************************/

#ifndef PYTORCH_NPU_HELPER_HPP_
#define PYTORCH_NPU_HELPER_HPP_

#include <torch_npu/csrc/aten/NPUNativeFunctions.h>
#include <torch_npu/csrc/framework/utils/CalcuOpUtil.h>
#include <torch_npu/csrc/framework/utils/OpAdapter.h>

#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"

#define NPU_NAME_SPACE at_npu::native

#define REGISTER_NPU_IMPL(key, value) REGISTER_DEVICE_IMPL(key, XLA, value)

#define CHECK_NPU(x) \
  TORCH_CHECK(x.device().type() == at::kXLA, #x " must be an NPU tensor")

#endif // PYTORCH_NPU_HELPER_HPP_
134 changes: 134 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp
@@ -0,0 +1,134 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;

void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
Tensor output, float gamma, float alpha) {
at::Tensor target_y = at::reshape(target, input.sizes());
target_y =
at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt);
int64_t weight_size = weight.size(0);
at::Tensor weight_y = at::ones_like(input);
if (weight_size > 0) {
weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight,
input.sizes());
}
OpCommand cmd;
cmd.Name("SigmoidFocalLoss")
.Input(input)
.Input(target_y)
.Input(weight_y)
.Output(output)
.Attr("gamma", gamma)
.Attr("alpha", alpha)
.Attr("reduction", "none")
.Run();
}

void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
Tensor output, float gamma, float alpha);

void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
Tensor grad_input, float gamma,
float alpha) {
at::Tensor target_y = at::reshape(target, input.sizes());
target_y =
at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt);
at::Tensor grad_up = at::ones_like(input);
int64_t weight_size = weight.size(0);
at::Tensor weight_y = at::ones_like(input);
if (weight_size > 0) {
weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight,
input.sizes());
}
OpCommand cmd;
cmd.Name("SigmoidFocalLossGrad")
.Input(input)
.Input(target_y)
.Input(grad_up)
.Input(weight_y)
.Output(grad_input)
.Attr("gamma", gamma)
.Attr("alpha", alpha)
.Attr("reduction", "none")
.Run();
}

void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target,
Tensor weight, Tensor grad_input,
float gamma, float alpha);

void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
Tensor output, float gamma, float alpha) {
int64_t n_class = input.size(1);
at::Tensor target_y =
at_npu::native::NPUNativeFunctions::one_hot(target, n_class);
target_y =
at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt);
int64_t weight_size = weight.size(0);
at::Tensor weight_y = at::ones_like(input);
if (weight_size > 0) {
weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight,
input.sizes());
}
OpCommand cmd;
cmd.Name("SoftmaxFocalLoss")
.Input(input)
.Input(target_y)
.Input(weight_y)
.Output(output)
.Attr("gamma", gamma)
.Attr("alpha", alpha)
.Attr("reduction", "none")
.Run();
}

void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
Tensor grad_input, float gamma,
float alpha);

void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
Tensor buff, Tensor grad_input,
float gamma, float alpha) {
int64_t n_class = input.size(1);
at::Tensor target_y =
at_npu::native::NPUNativeFunctions::one_hot(target, n_class);
target_y =
at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt);
at::Tensor grad_up = at::ones_like(input);
int64_t weight_size = weight.size(0);
at::Tensor weight_y = at::ones_like(input);
if (weight_size > 0) {
weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight,
input.sizes());
}

OpCommand cmd;
cmd.Name("SoftmaxFocalLossGrad")
.Input(input)
.Input(target_y)
.Input(grad_up)
.Input(weight_y)
.Output(grad_input)
.Attr("gamma", gamma)
.Attr("alpha", alpha)
.Attr("reduction", "none")
.Run();
}

void softmax_focal_loss_backward_impl(Tensor input, Tensor target,
Tensor weight, Tensor buff,
Tensor grad_input, float gamma,
float alpha);

REGISTER_NPU_IMPL(sigmoid_focal_loss_forward_impl,
sigmoid_focal_loss_forward_npu);

REGISTER_NPU_IMPL(sigmoid_focal_loss_backward_impl,
sigmoid_focal_loss_backward_npu);

REGISTER_NPU_IMPL(softmax_focal_loss_forward_impl,
softmax_focal_loss_forward_npu);

REGISTER_NPU_IMPL(softmax_focal_loss_backward_impl,
softmax_focal_loss_backward_npu);
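
Once these kernels are registered, the existing Python entry points dispatch to them on NPU tensors; a hedged sketch (shapes are illustrative):

```python
# A minimal sketch, not part of this PR: focal loss on NPU tensors.
import torch
from mmcv.ops import sigmoid_focal_loss

pred = torch.randn(4, 10).npu()           # (N, num_classes) logits
target = torch.randint(0, 10, (4,)).npu() # (N,) class indices, int64
loss = sigmoid_focal_loss(pred, target, gamma=2.0, alpha=0.25)
```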
3 changes: 2 additions & 1 deletion mmcv/ops/focal_loss.py
@@ -143,7 +143,8 @@ def forward(ctx,
                weight: Optional[torch.Tensor] = None,
                reduction='mean') -> torch.Tensor:

        assert isinstance(target, (torch.LongTensor, torch.cuda.LongTensor))
        assert isinstance(
            target, (torch.Tensor, torch.LongTensor, torch.cuda.LongTensor))
        assert input.dim() == 2
        assert target.dim() == 1
        assert input.size(0) == target.size(0)
10 changes: 9 additions & 1 deletion mmcv/runner/dist_utils.py
@@ -13,7 +13,7 @@
from torch._utils import (_flatten_dense_tensors, _take_tensors,
                          _unflatten_dense_tensors)

from mmcv.utils import IS_MLU_AVAILABLE
from mmcv.utils import IS_MLU_AVAILABLE, IS_NPU_AVAILABLE


def _find_free_port() -> str:
@@ -58,6 +58,14 @@ def _init_dist_pytorch(backend: str, **kwargs) -> None:
            rank=rank,
            world_size=int(os.environ['WORLD_SIZE']),
            **kwargs)
    elif IS_NPU_AVAILABLE:
        import torch_npu  # noqa: F401
        torch.npu.set_device(rank)
        dist.init_process_group(
            backend='hccl',
            rank=rank,
            world_size=int(os.environ['WORLD_SIZE']),
            **kwargs)
    else:
        num_gpus = torch.cuda.device_count()
        torch.cuda.set_device(rank % num_gpus)
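
For context, a hedged sketch of how this branch is reached (the launcher sets RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT, as in the usual PyTorch flow):

```python
# A minimal sketch, not part of this PR: initialize distributed NPU training.
from mmcv.runner import init_dist

init_dist('pytorch')  # on an NPU machine this takes the 'hccl' branch above
```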
7 changes: 5 additions & 2 deletions mmcv/runner/fp16_utils.py
@@ -10,15 +10,18 @@
import torch.nn as nn
from torch.nn.parameter import Parameter

from mmcv.utils import TORCH_VERSION, digit_version
from mmcv.utils import IS_NPU_AVAILABLE, TORCH_VERSION, digit_version
from .dist_utils import allreduce_grads as _allreduce_grads

try:
    # If PyTorch version >= 1.6.0, torch.cuda.amp.autocast would be imported
    # and used; otherwise, auto fp16 will adopt mmcv's implementation.
    # Note that when PyTorch >= 1.6.0, we still cast tensor types to fp16
    # manually, so the behavior may not be consistent with real amp.
    from torch.cuda.amp import autocast
    if IS_NPU_AVAILABLE:
        from torch.npu.amp import autocast
    else:
        from torch.cuda.amp import autocast
except ImportError:
    pass

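
With this change, code that imports `autocast` from `mmcv.runner.fp16_utils` transparently uses the NPU implementation; a hedged standalone sketch (assuming `torch.npu.amp.autocast` mirrors the `torch.cuda.amp` API, which is what this import switch relies on):

```python
# A minimal sketch, not part of this PR.
import torch
from torch.npu.amp import autocast  # assumed to mirror torch.cuda.amp

model = torch.nn.Linear(8, 2).npu()
with autocast():
    out = model(torch.randn(4, 8).npu())  # runs in mixed precision
```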
8 changes: 6 additions & 2 deletions mmcv/runner/hooks/optimizer.py
@@ -9,15 +9,19 @@
from torch import Tensor
from torch.nn.utils import clip_grad

from mmcv.utils import TORCH_VERSION, _BatchNorm, digit_version
from mmcv.utils import (IS_NPU_AVAILABLE, TORCH_VERSION, _BatchNorm,
                        digit_version)
from ..dist_utils import allreduce_grads
from ..fp16_utils import LossScaler, wrap_fp16_model
from .hook import HOOKS, Hook

try:
    # If PyTorch version >= 1.6.0, torch.cuda.amp.GradScaler would be imported
    # and used; otherwise, auto fp16 will adopt mmcv's implementation.
    from torch.cuda.amp import GradScaler
    if IS_NPU_AVAILABLE:
        from torch.npu.amp import GradScaler
    else:
        from torch.cuda.amp import GradScaler
except ImportError:
    pass

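
Likewise for `GradScaler`; a hedged sketch of the standard scaled-loss step (again assuming the NPU class mirrors `torch.cuda.amp.GradScaler`):

```python
# A minimal sketch, not part of this PR.
import torch
import torch.nn.functional as F
from torch.npu.amp import GradScaler, autocast  # assumed to mirror cuda.amp

model = torch.nn.Linear(8, 2).npu()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = GradScaler()

data = torch.randn(4, 8).npu()
label = torch.randint(0, 2, (4,)).npu()

optimizer.zero_grad()
with autocast():
    loss = F.cross_entropy(model(data), label)
scaler.scale(loss).backward()  # scale the loss to avoid fp16 underflow
scaler.step(optimizer)         # unscales grads, then calls optimizer.step()
scaler.update()                # adjusts the scale for the next iteration
```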
4 changes: 2 additions & 2 deletions mmcv/utils/__init__.py
@@ -37,7 +37,7 @@
]
else:
    from .device_type import (IS_IPU_AVAILABLE, IS_MLU_AVAILABLE,
                              IS_MPS_AVAILABLE)
                              IS_MPS_AVAILABLE, IS_NPU_AVAILABLE)
    from .env import collect_env
    from .hub import load_url
    from .logging import get_logger, print_log
@@ -77,5 +77,5 @@
    'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch',
    '_get_cuda_home', 'load_url', 'has_method', 'IS_CUDA_AVAILABLE',
    'worker_init_fn', 'IS_MLU_AVAILABLE', 'IS_IPU_AVAILABLE',
    'IS_MPS_AVAILABLE', 'torch_meshgrid'
    'IS_MPS_AVAILABLE', 'IS_NPU_AVAILABLE', 'torch_meshgrid'
]