Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,8 @@ jobs:
strategy:
matrix:
python-version: [3.7]
torch: [1.3.1, 1.5.1+cu101, 1.6.0+cu101, 1.7.0+cu101, 1.8.0+cu101]
torch: [1.5.1+cu101, 1.6.0+cu101, 1.7.0+cu101, 1.8.0+cu101]
include:
- torch: 1.3.1
torchvision: 0.4.2
- torch: 1.5.1+cu101
torchvision: 0.6.1+cu101
- torch: 1.6.0+cu101
Expand Down Expand Up @@ -362,10 +360,8 @@ jobs:
runs-on: macos-latest
strategy:
matrix:
torch: [1.3.1, 1.5.1, 1.6.0, 1.7.0, 1.8.0, 1.9.0]
torch: [1.5.1, 1.6.0, 1.7.0, 1.8.0, 1.9.0]
include:
- torch: 1.3.1
torchvision: 0.4.2
- torch: 1.5.1
torchvision: 0.6.1
- torch: 1.6.0
Expand Down
12 changes: 6 additions & 6 deletions mmcv/ops/csrc/common/cuda/nms_cuda_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ __device__ inline bool devIoU(float const *const a, float const *const b,
return interS > threshold * (Sa + Sb - interS);
}

__global__ void nms_cuda(const int n_boxes, const float iou_threshold,
const int offset, const float *dev_boxes,
unsigned long long *dev_mask) {
__global__ static void nms_cuda(const int n_boxes, const float iou_threshold,
const int offset, const float *dev_boxes,
unsigned long long *dev_mask) {
int blocks = (n_boxes + threadsPerBlock - 1) / threadsPerBlock;
CUDA_2D_KERNEL_BLOCK_LOOP(col_start, blocks, row_start, blocks) {
const int tid = threadIdx.x;
Expand Down Expand Up @@ -73,9 +73,9 @@ __global__ void nms_cuda(const int n_boxes, const float iou_threshold,
}
}

__global__ void gather_keep_from_mask(bool *keep,
const unsigned long long *dev_mask,
const int n_boxes) {
__global__ static void gather_keep_from_mask(bool *keep,
const unsigned long long *dev_mask,
const int n_boxes) {
const int col_blocks = (n_boxes + threadsPerBlock - 1) / threadsPerBlock;
const int tid = threadIdx.x;

Expand Down
30 changes: 14 additions & 16 deletions mmcv/ops/csrc/parrots/cudabind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -570,14 +570,12 @@ void IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(const int num_a,
const Tensor boxes_b,
Tensor ans_overlap);

void IoU3DNMS3DForwardCUDAKernelLauncher(const Tensor boxes,
unsigned long long* mask,
int boxes_num,
void IoU3DNMS3DForwardCUDAKernelLauncher(const Tensor boxes, Tensor& keep,
Tensor& keep_num,
float nms_overlap_thresh);

void IoU3DNMS3DNormalForwardCUDAKernelLauncher(const Tensor boxes,
unsigned long long* mask,
int boxes_num,
void IoU3DNMS3DNormalForwardCUDAKernelLauncher(const Tensor boxes, Tensor& keep,
Tensor& keep_num,
float nms_overlap_thresh);

void iou3d_boxes_overlap_bev_forward_cuda(const int num_a, const Tensor boxes_a,
Expand All @@ -587,28 +585,28 @@ void iou3d_boxes_overlap_bev_forward_cuda(const int num_a, const Tensor boxes_a,
ans_overlap);
};

void iou3d_nms3d_forward_cuda(const Tensor boxes, unsigned long long* mask,
int boxes_num, float nms_overlap_thresh) {
IoU3DNMS3DForwardCUDAKernelLauncher(boxes, mask, boxes_num,
void iou3d_nms3d_forward_cuda(const Tensor boxes, Tensor& keep,
Tensor& keep_num, float nms_overlap_thresh) {
IoU3DNMS3DForwardCUDAKernelLauncher(boxes, keep, keep_num,
nms_overlap_thresh);
};

void iou3d_nms3d_normal_forward_cuda(const Tensor boxes,
unsigned long long* mask, int boxes_num,
void iou3d_nms3d_normal_forward_cuda(const Tensor boxes, Tensor& keep,
Tensor& keep_num,
float nms_overlap_thresh) {
IoU3DNMS3DNormalForwardCUDAKernelLauncher(boxes, mask, boxes_num,
IoU3DNMS3DNormalForwardCUDAKernelLauncher(boxes, keep, keep_num,
nms_overlap_thresh);
};

void iou3d_boxes_overlap_bev_forward_impl(const int num_a, const Tensor boxes_a,
const int num_b, const Tensor boxes_b,
Tensor ans_overlap);

void iou3d_nms3d_forward_impl(const Tensor boxes, unsigned long long* mask,
int boxes_num, float nms_overlap_thresh);
void iou3d_nms3d_forward_impl(const Tensor boxes, Tensor& keep,
Tensor& keep_num, float nms_overlap_thresh);

void iou3d_nms3d_normal_forward_impl(const Tensor boxes,
unsigned long long* mask, int boxes_num,
void iou3d_nms3d_normal_forward_impl(const Tensor boxes, Tensor& keep,
Tensor& keep_num,
float nms_overlap_thresh);

REGISTER_DEVICE_IMPL(iou3d_boxes_overlap_bev_forward_impl, CUDA,
Expand Down
85 changes: 8 additions & 77 deletions mmcv/ops/csrc/parrots/iou3d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@ void iou3d_boxes_overlap_bev_forward_impl(const int num_a, const Tensor boxes_a,
num_b, boxes_b, ans_overlap);
}

void iou3d_nms3d_forward_impl(const Tensor boxes, unsigned long long *mask,
int boxes_num, float nms_overlap_thresh) {
DISPATCH_DEVICE_IMPL(iou3d_nms3d_forward_impl, boxes, mask, boxes_num,
void iou3d_nms3d_forward_impl(const Tensor boxes, Tensor &keep,
Tensor &keep_num, float nms_overlap_thresh) {
DISPATCH_DEVICE_IMPL(iou3d_nms3d_forward_impl, boxes, keep, keep_num,
nms_overlap_thresh);
}

void iou3d_nms3d_normal_forward_impl(const Tensor boxes,
unsigned long long *mask, int boxes_num,
void iou3d_nms3d_normal_forward_impl(const Tensor boxes, Tensor &keep,
Tensor &keep_num,
float nms_overlap_thresh) {
DISPATCH_DEVICE_IMPL(iou3d_nms3d_normal_forward_impl, boxes, mask, boxes_num,
DISPATCH_DEVICE_IMPL(iou3d_nms3d_normal_forward_impl, boxes, keep, keep_num,
nms_overlap_thresh);
}

Expand All @@ -51,41 +51,7 @@ void iou3d_nms3d_forward(Tensor boxes, Tensor keep, Tensor keep_num,
CHECK_CONTIGUOUS(boxes);
CHECK_CONTIGUOUS(keep);

int boxes_num = boxes.size(0);
int64_t *keep_data = keep.data_ptr<int64_t>();
int64_t *keep_num_data = keep_num.data_ptr<int64_t>();

const int col_blocks =
(boxes_num + THREADS_PER_BLOCK_NMS - 1) / THREADS_PER_BLOCK_NMS;

Tensor mask =
at::empty({boxes_num, col_blocks}, boxes.options().dtype(at::kLong));
unsigned long long *mask_data =
(unsigned long long *)mask.data_ptr<int64_t>();
iou3d_nms3d_forward_impl(boxes, mask_data, boxes_num, nms_overlap_thresh);

at::Tensor mask_cpu = mask.to(at::kCPU);
unsigned long long *mask_host =
(unsigned long long *)mask_cpu.data_ptr<int64_t>();

std::vector<unsigned long long> remv_cpu(col_blocks);
memset(&remv_cpu[0], 0, sizeof(unsigned long long) * col_blocks);

int num_to_keep = 0;

for (int i = 0; i < boxes_num; i++) {
int nblock = i / THREADS_PER_BLOCK_NMS;
int inblock = i % THREADS_PER_BLOCK_NMS;

if (!(remv_cpu[nblock] & (1ULL << inblock))) {
keep_data[num_to_keep++] = i;
unsigned long long *p = &mask_host[0] + i * col_blocks;
for (int j = nblock; j < col_blocks; j++) {
remv_cpu[j] |= p[j];
}
}
*keep_num_data = num_to_keep;
}
iou3d_nms3d_forward_impl(boxes, keep, keep_num, nms_overlap_thresh);
}

void iou3d_nms3d_normal_forward(Tensor boxes, Tensor keep, Tensor keep_num,
Expand All @@ -96,40 +62,5 @@ void iou3d_nms3d_normal_forward(Tensor boxes, Tensor keep, Tensor keep_num,
CHECK_CONTIGUOUS(boxes);
CHECK_CONTIGUOUS(keep);

int boxes_num = boxes.size(0);
int64_t *keep_data = keep.data_ptr<int64_t>();
int64_t *keep_num_data = keep_num.data_ptr<int64_t>();

const int col_blocks =
(boxes_num + THREADS_PER_BLOCK_NMS - 1) / THREADS_PER_BLOCK_NMS;

Tensor mask =
at::empty({boxes_num, col_blocks}, boxes.options().dtype(at::kLong));
unsigned long long *mask_data =
(unsigned long long *)mask.data_ptr<int64_t>();
iou3d_nms3d_normal_forward_impl(boxes, mask_data, boxes_num,
nms_overlap_thresh);

at::Tensor mask_cpu = mask.to(at::kCPU);
unsigned long long *mask_host =
(unsigned long long *)mask_cpu.data_ptr<int64_t>();

std::vector<unsigned long long> remv_cpu(col_blocks);
memset(&remv_cpu[0], 0, sizeof(unsigned long long) * col_blocks);
int num_to_keep = 0;

for (int i = 0; i < boxes_num; i++) {
int nblock = i / THREADS_PER_BLOCK_NMS;
int inblock = i % THREADS_PER_BLOCK_NMS;

if (!(remv_cpu[nblock] & (1ULL << inblock))) {
keep_data[num_to_keep++] = i;
unsigned long long *p = &mask_host[0] + i * col_blocks;
for (int j = nblock; j < col_blocks; j++) {
remv_cpu[j] |= p[j];
}
}
}

*keep_num_data = num_to_keep;
iou3d_nms3d_normal_forward_impl(boxes, keep, keep_num, nms_overlap_thresh);
}
30 changes: 14 additions & 16 deletions mmcv/ops/csrc/pytorch/cuda/cudabind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -570,14 +570,12 @@ void IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(const int num_a,
const Tensor boxes_b,
Tensor ans_overlap);

void IoU3DNMS3DForwardCUDAKernelLauncher(const Tensor boxes,
unsigned long long* mask,
int boxes_num,
void IoU3DNMS3DForwardCUDAKernelLauncher(const Tensor boxes, Tensor& keep,
Tensor& keep_num,
float nms_overlap_thresh);

void IoU3DNMS3DNormalForwardCUDAKernelLauncher(const Tensor boxes,
unsigned long long* mask,
int boxes_num,
void IoU3DNMS3DNormalForwardCUDAKernelLauncher(const Tensor boxes, Tensor& keep,
Tensor& keep_num,
float nms_overlap_thresh);

void iou3d_boxes_overlap_bev_forward_cuda(const int num_a, const Tensor boxes_a,
Expand All @@ -587,28 +585,28 @@ void iou3d_boxes_overlap_bev_forward_cuda(const int num_a, const Tensor boxes_a,
ans_overlap);
};

void iou3d_nms3d_forward_cuda(const Tensor boxes, unsigned long long* mask,
int boxes_num, float nms_overlap_thresh) {
IoU3DNMS3DForwardCUDAKernelLauncher(boxes, mask, boxes_num,
void iou3d_nms3d_forward_cuda(const Tensor boxes, Tensor& keep,
Tensor& keep_num, float nms_overlap_thresh) {
IoU3DNMS3DForwardCUDAKernelLauncher(boxes, keep, keep_num,
nms_overlap_thresh);
};

void iou3d_nms3d_normal_forward_cuda(const Tensor boxes,
unsigned long long* mask, int boxes_num,
void iou3d_nms3d_normal_forward_cuda(const Tensor boxes, Tensor& keep,
Tensor& keep_num,
float nms_overlap_thresh) {
IoU3DNMS3DNormalForwardCUDAKernelLauncher(boxes, mask, boxes_num,
IoU3DNMS3DNormalForwardCUDAKernelLauncher(boxes, keep, keep_num,
nms_overlap_thresh);
};

void iou3d_boxes_overlap_bev_forward_impl(const int num_a, const Tensor boxes_a,
const int num_b, const Tensor boxes_b,
Tensor ans_overlap);

void iou3d_nms3d_forward_impl(const Tensor boxes, unsigned long long* mask,
int boxes_num, float nms_overlap_thresh);
void iou3d_nms3d_forward_impl(const Tensor boxes, Tensor& keep,
Tensor& keep_num, float nms_overlap_thresh);

void iou3d_nms3d_normal_forward_impl(const Tensor boxes,
unsigned long long* mask, int boxes_num,
void iou3d_nms3d_normal_forward_impl(const Tensor boxes, Tensor& keep,
Tensor& keep_num,
float nms_overlap_thresh);

REGISTER_DEVICE_IMPL(iou3d_boxes_overlap_bev_forward_impl, CUDA,
Expand Down
53 changes: 45 additions & 8 deletions mmcv/ops/csrc/pytorch/cuda/iou3d_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ All Rights Reserved 2019-2020.
#include <stdio.h>

#include "iou3d_cuda_kernel.cuh"
#include "nms_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"

void IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(const int num_a,
Expand All @@ -32,36 +33,72 @@ void IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(const int num_a,
AT_CUDA_CHECK(cudaGetLastError());
}

void IoU3DNMS3DForwardCUDAKernelLauncher(const Tensor boxes,
unsigned long long *mask,
int boxes_num,
void IoU3DNMS3DForwardCUDAKernelLauncher(const Tensor boxes, Tensor& keep,
Tensor& keep_num,
float nms_overlap_thresh) {
using namespace at::indexing;
at::cuda::CUDAGuard device_guard(boxes.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

int boxes_num = boxes.size(0);

const int col_blocks =
(boxes_num + THREADS_PER_BLOCK_NMS - 1) / THREADS_PER_BLOCK_NMS;
Tensor mask =
at::empty({boxes_num, col_blocks}, boxes.options().dtype(at::kLong));

dim3 blocks(GET_BLOCKS(boxes_num, THREADS_PER_BLOCK_NMS),
GET_BLOCKS(boxes_num, THREADS_PER_BLOCK_NMS));
dim3 threads(THREADS_PER_BLOCK_NMS);

iou3d_nms3d_forward_cuda_kernel<<<blocks, threads, 0, stream>>>(
boxes_num, nms_overlap_thresh, boxes.data_ptr<float>(), mask);
boxes_num, nms_overlap_thresh, boxes.data_ptr<float>(),
(unsigned long long*)mask.data_ptr<int64_t>());

at::Tensor keep_t = at::zeros(
{boxes_num}, boxes.options().dtype(at::kBool).device(at::kCUDA));
gather_keep_from_mask<<<1, min(col_blocks, THREADS_PER_BLOCK),
col_blocks * sizeof(unsigned long long), stream>>>(
keep_t.data_ptr<bool>(), (unsigned long long*)mask.data_ptr<int64_t>(),
boxes_num);

auto keep_data = keep_t.nonzero().index({Slice(), 0});
keep_num.fill_(at::Scalar(keep_data.size(0)));
keep.index_put_({Slice(0, keep_data.size(0))}, keep_data);
AT_CUDA_CHECK(cudaGetLastError());
}

void IoU3DNMS3DNormalForwardCUDAKernelLauncher(const Tensor boxes,
unsigned long long *mask,
int boxes_num,
void IoU3DNMS3DNormalForwardCUDAKernelLauncher(const Tensor boxes, Tensor& keep,
Tensor& keep_num,
float nms_overlap_thresh) {
using namespace at::indexing;
at::cuda::CUDAGuard device_guard(boxes.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

int boxes_num = boxes.size(0);

const int col_blocks =
(boxes_num + THREADS_PER_BLOCK_NMS - 1) / THREADS_PER_BLOCK_NMS;
Tensor mask =
at::empty({boxes_num, col_blocks}, boxes.options().dtype(at::kLong));

dim3 blocks(GET_BLOCKS(boxes_num, THREADS_PER_BLOCK_NMS),
GET_BLOCKS(boxes_num, THREADS_PER_BLOCK_NMS));
dim3 threads(THREADS_PER_BLOCK_NMS);

iou3d_nms3d_normal_forward_cuda_kernel<<<blocks, threads, 0, stream>>>(
boxes_num, nms_overlap_thresh, boxes.data_ptr<float>(), mask);
boxes_num, nms_overlap_thresh, boxes.data_ptr<float>(),
(unsigned long long*)mask.data_ptr<int64_t>());

at::Tensor keep_t = at::zeros(
{boxes_num}, boxes.options().dtype(at::kBool).device(at::kCUDA));
gather_keep_from_mask<<<1, min(col_blocks, THREADS_PER_BLOCK),
col_blocks * sizeof(unsigned long long), stream>>>(
keep_t.data_ptr<bool>(), (unsigned long long*)mask.data_ptr<int64_t>(),
boxes_num);

auto keep_data = keep_t.nonzero().index({Slice(), 0});
keep_num.fill_(at::Scalar(keep_data.size(0)));
keep.index_put_({Slice(0, keep_data.size(0))}, keep_data);
AT_CUDA_CHECK(cudaGetLastError());
}
Loading