-
Notifications
You must be signed in to change notification settings - Fork 1.7k
[Feature] Support engine with NPU backend. #2262
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Changes from 16 commits
Commits
Show all changes
23 commits
Select commit
Hold shift + click to select a range
093f462
init npu
wangjiangben-hw ba9c36b
Merge pull request #2261 from wangjiangben-hw/npu-dev
ckirchhoff2021 716b3b3
add npu extension and focal loss adapter
ckirchhoff2021 6e53b3f
clean code
ckirchhoff2021 08f0a16
clean code
ckirchhoff2021 da659cb
clean code
ckirchhoff2021 448476e
clean code
ckirchhoff2021 afbd351
fix autocast bugs on npu (#2273)
wangjiangben-hw 26f35e0
code format
ckirchhoff2021 58618ac
code format
ckirchhoff2021 d75dcb1
code format
ckirchhoff2021 7af2a6f
bug fix
ckirchhoff2021 268ff0e
pytorch_npu_helper.hpp clean code
ckirchhoff2021 9b0133c
Merge pull request #2269 from ckirchhoff2021/npu-dev
ckirchhoff2021 a043541
Npu dev (#2306)
wangjiangben-hw 90fc3dc
raise ImportError when compile with npu
wangjiangben-hw 6ffdb57
add npu test case (#2307)
wangjiangben-hw 3f11861
Update focal_loss.py
wangjiangben-hw c81daa3
add comment
wangjiangben-hw 1cee865
clean lint
wangjiangben-hw 5841f92
update dtype assert
wangjiangben-hw 4ce938c
update DDP forward and comment
wangjiangben-hw ea1d8f8
fix bug
wangjiangben-hw File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,8 @@ | ||
| # Copyright (c) OpenMMLab. All rights reserved. | ||
| from . import ipu, mlu, mps | ||
| from . import ipu, mlu, mps, npu | ||
| from .scatter_gather import scatter, scatter_kwargs | ||
| from .utils import get_device | ||
|
|
||
| __all__ = ['mlu', 'ipu', 'mps', 'get_device', 'scatter', 'scatter_kwargs'] | ||
| __all__ = [ | ||
| 'npu', 'mlu', 'ipu', 'mps', 'get_device', 'scatter', 'scatter_kwargs' | ||
| ] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| # Copyright Huawei Technologies Co., Ltd. All rights reserved. | ||
| # Copyright (c) OpenMMLab. All rights reserved. | ||
| from .data_parallel import NPUDataParallel | ||
| from .distributed import NPUDistributedDataParallel | ||
|
|
||
| __all__ = ['NPUDataParallel', 'NPUDistributedDataParallel'] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| # Copyright Huawei Technologies Co., Ltd. All rights reserved. | ||
| # Copyright (c) OpenMMLab. All rights reserved. | ||
|
|
||
| import sys | ||
|
|
||
| import torch | ||
|
|
||
| from mmcv.device.scatter_gather import scatter_kwargs | ||
| from mmcv.parallel import MMDataParallel | ||
|
|
||
|
|
||
| def _check_balance(*args, **kwargs): | ||
| return | ||
|
|
||
|
|
||
| # Since we do not have a similar hardware unit multi_processor | ||
| # on the NPU, the corresponding# devices_properties does not | ||
| # have this property and cannot be checked. So we masked the | ||
| # _check_balance function in DataParallel to make initialization pass. | ||
| for m in sys.modules: | ||
| if m.startswith('torch') or 'mmcv' in m: | ||
| if hasattr(sys.modules[m], '_check_balance'): | ||
| setattr(sys.modules[m], '_check_balance', _check_balance) | ||
|
|
||
|
|
||
| class NPUDataParallel(MMDataParallel): | ||
| """The NPUDataParallel module that supports DataContainer. | ||
|
|
||
| NPUDataParallel is a class inherited from MMDataParall, which supports | ||
| NPU training and inference only. | ||
|
|
||
| The main differences with MMDataParallel: | ||
|
|
||
| - It only supports single-card of NPU, and only use first card to | ||
| run training and inference. | ||
|
|
||
| - It uses direct host-to-device copy instead of stream-background | ||
| scatter. | ||
|
|
||
| .. warning:: | ||
| NPUDataParallel only supports single NPU training, if you need to | ||
| train with multiple NPUs, please use NPUDistributedDataParallel | ||
| instead. If you have multiple NPUs, you can toggle device_ids | ||
| parameters passed in for this function to specify the running device. | ||
|
|
||
| Args: | ||
| module (:class:`nn.Module`): Module to be encapsulated. | ||
| dim (int): Dimension used to scatter the data. Defaults to 0. | ||
| """ | ||
|
|
||
| def __init__(self, *args, dim=0, **kwargs): | ||
| super().__init__(*args, dim=dim, **kwargs) | ||
| device_id = kwargs.get('device_ids', [0])[0] | ||
| self.device_ids = [device_id] | ||
| self.src_device_obj = torch.device(f'npu:{device_id}') | ||
| torch.npu.set_device(self.src_device_obj) | ||
|
|
||
| def scatter(self, inputs, kwargs, device_ids): | ||
| return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| # Copyright Huawei Technologies Co., Ltd. All rights reserved. | ||
| # Copyright (c) OpenMMLab. All rights reserved. | ||
|
|
||
| from mmcv.device.scatter_gather import scatter_kwargs | ||
| from mmcv.parallel import MMDistributedDataParallel | ||
|
|
||
|
|
||
| class NPUDistributedDataParallel(MMDistributedDataParallel): | ||
| """The DDP module supports DataContainer. | ||
| NPUDDP has one difference from MMDDP which moves data to NPU with coping | ||
| instead of scattering. | ||
| """ | ||
|
|
||
| def to_kwargs(self, inputs, kwargs, device_id): | ||
| # Use `self.to_kwargs` instead of `self.scatter` in pytorch1.8 | ||
| # to move all tensors to device_id | ||
| return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim) | ||
|
|
||
| def scatter(self, inputs, kwargs, device_ids): | ||
| return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) | ||
|
|
||
| def forward(self, *inputs, **kwargs): | ||
| if self.device_ids: | ||
| inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) | ||
| return super().forward(*inputs[0], **kwargs[0]) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,35 @@ | ||
| /****************************************************************************** | ||
| * Copyright (c) 2022 Huawei Technologies Co., Ltd | ||
| * All rights reserved. | ||
| * | ||
| * Licensed under the BSD 3-Clause License (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * https://opensource.org/licenses/BSD-3-Clause | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| ******************************************************************************/ | ||
|
|
||
| #ifndef PYTORCH_NPU_HELPER_HPP_ | ||
| #define PYTORCH_NPU_HELPER_HPP_ | ||
|
|
||
| #include <torch_npu/csrc/aten/NPUNativeFunctions.h> | ||
| #include <torch_npu/csrc/framework/utils/CalcuOpUtil.h> | ||
| #include <torch_npu/csrc/framework/utils/OpAdapter.h> | ||
|
|
||
| #include "pytorch_cpp_helper.hpp" | ||
| #include "pytorch_device_registry.hpp" | ||
|
|
||
| #define NPU_NAME_SPACE at_npu::native | ||
|
|
||
| #define REGISTER_NPU_IMPL(key, value) REGISTER_DEVICE_IMPL(key, XLA, value) | ||
|
|
||
| #define CHECK_NPU(x) \ | ||
| TORCH_CHECK(x.device().type() == at::kXLA, #x " must be a NPU tensor") | ||
|
|
||
| #endif // PYTORCH_NPU_HELPER_HPP_ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,134 @@ | ||
| #include "pytorch_npu_helper.hpp" | ||
|
|
||
| using namespace NPU_NAME_SPACE; | ||
|
|
||
| void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, | ||
| Tensor output, float gamma, float alpha) { | ||
| at::Tensor target_y = at::reshape(target, input.sizes()); | ||
| target_y = | ||
| at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt); | ||
| int64_t weight_size = weight.size(0); | ||
| at::Tensor weight_y = at::ones_like(input); | ||
| if (weight_size > 0) { | ||
| weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, | ||
| input.sizes()); | ||
| } | ||
| OpCommand cmd; | ||
| cmd.Name("SigmoidFocalLoss") | ||
| .Input(input) | ||
| .Input(target_y) | ||
| .Input(weight_y) | ||
| .Output(output) | ||
| .Attr("gamma", gamma) | ||
| .Attr("alpha", alpha) | ||
| .Attr("reduction", "none") | ||
| .Run(); | ||
| } | ||
|
|
||
| void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, | ||
| Tensor output, float gamma, float alpha); | ||
|
|
||
| void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, | ||
| Tensor grad_input, float gamma, | ||
| float alpha) { | ||
| at::Tensor target_y = at::reshape(target, input.sizes()); | ||
| target_y = | ||
| at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt); | ||
| at::Tensor grad_up = at::ones_like(input); | ||
| int64_t weight_size = weight.size(0); | ||
| at::Tensor weight_y = at::ones_like(input); | ||
| if (weight_size > 0) { | ||
| weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, | ||
| input.sizes()); | ||
| } | ||
| OpCommand cmd; | ||
| cmd.Name("SigmoidFocalLossGrad") | ||
| .Input(input) | ||
| .Input(target_y) | ||
| .Input(grad_up) | ||
| .Input(weight_y) | ||
| .Output(grad_input) | ||
| .Attr("gamma", gamma) | ||
| .Attr("alpha", alpha) | ||
| .Attr("reduction", "none") | ||
| .Run(); | ||
| } | ||
|
|
||
| void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target, | ||
| Tensor weight, Tensor grad_input, | ||
| float gamma, float alpha); | ||
|
|
||
| void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, | ||
| Tensor output, float gamma, float alpha) { | ||
| int64_t n_class = input.size(1); | ||
| at::Tensor target_y = | ||
| at_npu::native::NPUNativeFunctions::one_hot(target, n_class); | ||
| target_y = | ||
| at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt); | ||
| int64_t weight_size = weight.size(0); | ||
| at::Tensor weight_y = at::ones_like(input); | ||
| if (weight_size > 0) { | ||
| weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, | ||
| input.sizes()); | ||
| } | ||
| OpCommand cmd; | ||
| cmd.Name("SoftmaxFocalLoss") | ||
| .Input(input) | ||
| .Input(target_y) | ||
| .Input(weight_y) | ||
| .Output(output) | ||
| .Attr("gamma", gamma) | ||
| .Attr("alpha", alpha) | ||
| .Attr("reduction", "none") | ||
| .Run(); | ||
| } | ||
|
|
||
| void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, | ||
| Tensor grad_input, float gamma, | ||
| float alpha); | ||
|
|
||
| void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, | ||
| Tensor buff, Tensor grad_input, | ||
| float gamma, float alpha) { | ||
| int64_t n_class = input.size(1); | ||
| at::Tensor target_y = | ||
| at_npu::native::NPUNativeFunctions::one_hot(target, n_class); | ||
| target_y = | ||
| at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt); | ||
| at::Tensor grad_up = at::ones_like(input); | ||
| int64_t weight_size = weight.size(0); | ||
| at::Tensor weight_y = at::ones_like(input); | ||
| if (weight_size > 0) { | ||
| weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, | ||
| input.sizes()); | ||
| } | ||
|
|
||
| OpCommand cmd; | ||
| cmd.Name("SoftmaxFocalLossGrad") | ||
| .Input(input) | ||
| .Input(target_y) | ||
| .Input(grad_up) | ||
| .Input(weight_y) | ||
| .Output(grad_input) | ||
| .Attr("gamma", gamma) | ||
| .Attr("alpha", alpha) | ||
| .Attr("reduction", "none") | ||
| .Run(); | ||
| } | ||
|
|
||
| void softmax_focal_loss_backward_impl(Tensor input, Tensor target, | ||
| Tensor weight, Tensor buff, | ||
| Tensor grad_input, float gamma, | ||
| float alpha); | ||
|
|
||
| REGISTER_NPU_IMPL(sigmoid_focal_loss_forward_impl, | ||
| sigmoid_focal_loss_forward_npu); | ||
|
|
||
| REGISTER_NPU_IMPL(sigmoid_focal_loss_backward_impl, | ||
| sigmoid_focal_loss_backward_npu); | ||
|
|
||
| REGISTER_NPU_IMPL(softmax_focal_loss_forward_impl, | ||
| softmax_focal_loss_forward_npu); | ||
|
|
||
| REGISTER_NPU_IMPL(softmax_focal_loss_backward_impl, | ||
| softmax_focal_loss_backward_npu); |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.