Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion mmcv/ops/modulated_deform_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,10 +406,13 @@ def forward(self, x):
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
x = x.type_as(offset)
weight = self.weight.type_as(x)
mask = mask.type_as(x)
return tv_deform_conv2d(
x,
offset,
self.weight,
weight,
bias=self.bias,
stride=self.stride,
padding=self.padding,
Expand Down
19 changes: 12 additions & 7 deletions tests/test_ops/test_modulated_deform_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _test_mdconv(self, dtype=torch.float, device='cuda'):
assert numpy.allclose(dcn.conv_offset.bias.grad.cpu().detach().numpy(),
dcn_offset_b_grad, 1e-2)

def _test_amp_mdconv(self, input_dtype=torch.float):
def _test_amp_mdconv(self, input_dtype=torch.float, device='cuda'):
"""The function to test amp released on pytorch 1.6.0.

The type of input data might be torch.float or torch.half,
Expand All @@ -84,10 +84,15 @@ def _test_amp_mdconv(self, input_dtype=torch.float):
Args:
input_dtype: torch.float or torch.half.
"""
if not torch.cuda.is_available():
if not torch.cuda.is_available() and device == 'cuda':
return
from mmcv.ops import ModulatedDeformConv2dPack
input = torch.tensor(input_t).cuda().type(input_dtype)
if device == 'mlu':
from mmcv.ops import \
ModulatedDeformConv2dPack_MLU as ModulatedDeformConv2dPack
else:
from mmcv.ops import ModulatedDeformConv2dPack

input = torch.tensor(input_t).to(device).type(input_dtype)
input.requires_grad = True

dcn = ModulatedDeformConv2dPack(
Expand All @@ -97,7 +102,7 @@ def _test_amp_mdconv(self, input_dtype=torch.float):
stride=1,
padding=1,
deform_groups=1,
bias=False).cuda()
bias=False).to(device)
dcn.weight.data.fill_(1.)
output = dcn(input)
output.sum().backward()
Expand Down Expand Up @@ -126,5 +131,5 @@ def test_mdconv(self):
if (TORCH_VERSION != 'parrots'
and digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
with autocast(enabled=True):
self._test_amp_mdconv(torch.float)
self._test_amp_mdconv(torch.half)
self._test_amp_mdconv(torch.float, device=device)
self._test_amp_mdconv(torch.half, device=device)