clean code
ckirchhoff2021 committed Sep 20, 2022
commit 08f0a16a8ae227bd23866d7b6f87bc415d090758
mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp (16 changes: 8 additions & 8 deletions)

Each changed line below appears as a removed/added pair whose visible text is identical, and the two remaining changes are blank lines; together with the "clean code" commit message, the edits appear to be formatting and trailing-whitespace cleanup only.
@@ -5,7 +5,7 @@ using namespace NPU_NAME_SPACE;
 using namespace std;
 
 
-void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
+void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
                                     Tensor output, float gamma, float alpha) {
 
   at::Tensor target_y = at::reshape(target, input.sizes());
@@ -30,10 +30,10 @@ void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
       .Run();
 }
 
-void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
+void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
                                      Tensor output, float gamma, float alpha);
 
-void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
+void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
                                      Tensor grad_input, float gamma, float alpha) {
 
   at::Tensor target_y = at::reshape(target, input.sizes());
@@ -58,10 +58,10 @@ void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
       .Run();
 }
 
-void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target, Tensor weight,
+void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target, Tensor weight,
                                       Tensor grad_input, float gamma, float alpha);
 
-void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
+void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
                                     Tensor output, float gamma, float alpha) {
 
   int64_t n_class = input.size(1);
@@ -73,7 +73,7 @@ void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
   if(weight_size > 0) {
     weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, input.sizes());
   }
-
+
   OpCommand cmd;
   cmd.Name("SoftmaxFocalLoss")
       .Input(input)
@@ -87,7 +87,7 @@ void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
       .Run();
 }
 
-void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
+void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
                                      Tensor grad_input, float gamma, float alpha);
 
 void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, Tensor buff,
@@ -102,7 +102,7 @@ void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
   if(weight_size > 0) {
     weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, input.sizes());
   }
-
+
   OpCommand cmd;
   cmd.Name("SoftmaxFocalLossGrad")
       .Input(input)
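For context on the four *_impl functions that this file only declares: mmcv resolves them at runtime through its per-device dispatch registry, so an NPU build binds each generic entry point to the NPU kernel defined above. The registration lines live in the collapsed tail of the file and are not part of this diff, so the following is a sketch of mmcv's usual REGISTER_NPU_IMPL pattern, not a quote of the file:

// Sketch (assumed, not shown in this diff): bind the generic dispatch
// entry points to the NPU kernels implemented above.
REGISTER_NPU_IMPL(sigmoid_focal_loss_forward_impl,
                  sigmoid_focal_loss_forward_npu);
REGISTER_NPU_IMPL(sigmoid_focal_loss_backward_impl,
                  sigmoid_focal_loss_backward_npu);
REGISTER_NPU_IMPL(softmax_focal_loss_forward_impl,
                  softmax_focal_loss_forward_npu);
REGISTER_NPU_IMPL(softmax_focal_loss_backward_impl,
                  softmax_focal_loss_backward_npu);

With a registration like this in place, a call to sigmoid_focal_loss_forward_impl on NPU tensors lands in sigmoid_focal_loss_forward_npu, which builds and runs the "SigmoidFocalLoss" operator through the OpCommand chain shown in the diff.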