From eb015708bb0ae0468367b60ec8d809a5f1ec34fe Mon Sep 17 00:00:00 2001 From: alanvinx Date: Wed, 7 Aug 2024 14:36:07 +0200 Subject: [PATCH 1/5] depth regularization + antialiasing --- cuda_rasterizer/auxiliary.h | 17 ++- cuda_rasterizer/backward.cu | 162 ++++++++++++++++++------ cuda_rasterizer/backward.h | 10 +- cuda_rasterizer/forward.cu | 48 +++++-- cuda_rasterizer/forward.h | 6 +- cuda_rasterizer/rasterizer.h | 6 +- cuda_rasterizer/rasterizer_impl.cu | 20 ++- diff_gaussian_rasterization/__init__.py | 34 ++--- rasterize_points.cu | 29 ++++- rasterize_points.h | 6 +- 10 files changed, 251 insertions(+), 87 deletions(-) diff --git a/cuda_rasterizer/auxiliary.h b/cuda_rasterizer/auxiliary.h index 4d4b9b78..7f6b7548 100644 --- a/cuda_rasterizer/auxiliary.h +++ b/cuda_rasterizer/auxiliary.h @@ -17,7 +17,7 @@ #define BLOCK_SIZE (BLOCK_X * BLOCK_Y) #define NUM_WARPS (BLOCK_SIZE/32) - +#define DGR_FIX_AA // Spherical harmonics coefficients __device__ const float SH_C0 = 0.28209479177387814f; __device__ const float SH_C1 = 0.4886025119029199f; @@ -55,6 +55,19 @@ __forceinline__ __device__ void getRect(const float2 p, int max_radius, uint2& r }; } +__forceinline__ __device__ void getRect(const float2 p, int2 ext_rect, uint2& rect_min, uint2& rect_max, dim3 grid) +{ + rect_min = { + min(grid.x, max((int)0, (int)((p.x - ext_rect.x) / BLOCK_X))), + min(grid.y, max((int)0, (int)((p.y - ext_rect.y) / BLOCK_Y))) + }; + rect_max = { + min(grid.x, max((int)0, (int)((p.x + ext_rect.x + BLOCK_X - 1) / BLOCK_X))), + min(grid.y, max((int)0, (int)((p.y + ext_rect.y + BLOCK_Y - 1) / BLOCK_Y))) + }; +} + + __forceinline__ __device__ float3 transformPoint4x3(const float3& p, const float* matrix) { float3 transformed = { @@ -172,4 +185,4 @@ throw std::runtime_error(cudaGetErrorString(ret)); \ } \ } -#endif \ No newline at end of file +#endif diff --git a/cuda_rasterizer/backward.cu b/cuda_rasterizer/backward.cu index 4aa41e1c..7a609ac8 100644 --- a/cuda_rasterizer/backward.cu +++ b/cuda_rasterizer/backward.cu @@ -15,6 +15,9 @@ #include namespace cg = cooperative_groups; +__device__ __forceinline__ float sq(float x) { return x * x; } + + // Backward pass for conversion of spherical harmonics to RGB for // each Gaussian. __device__ void computeColorFromSH(int idx, int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float* shs, const bool* clamped, const glm::vec3* dL_dcolor, glm::vec3* dL_dmeans, glm::vec3* dL_dshs) @@ -148,7 +151,10 @@ __global__ void computeCov2DCUDA(int P, const float h_x, float h_y, const float tan_fovx, float tan_fovy, const float* view_matrix, + const float* opacities, const float* dL_dconics, + float* dL_dopacity, + const float* dL_dinvdepth, float3* dL_dmeans, float* dL_dcov) { @@ -194,12 +200,49 @@ __global__ void computeCov2DCUDA(int P, glm::mat3 cov2D = glm::transpose(T) * glm::transpose(Vrk) * T; // Use helper variables for 2D covariance entries. More compact. - float a = cov2D[0][0] += 0.3f; - float b = cov2D[0][1]; - float c = cov2D[1][1] += 0.3f; + float c_xx = cov2D[0][0]; + float c_xy = cov2D[0][1]; + float c_yy = cov2D[1][1]; + + constexpr float h_var = 0.3f; +#ifdef DGR_FIX_AA + const float det_cov = c_xx * c_yy - c_xy * c_xy; + c_xx += h_var; + c_yy += h_var; + const float det_cov_plus_h_cov = c_xx * c_yy - c_xy * c_xy; + const float h_convolution_scaling = sqrt(max(0.000025f, det_cov / det_cov_plus_h_cov)); // max for numerical stability + const float dL_dopacity_v = dL_dopacity[idx]; + const float d_h_convolution_scaling = dL_dopacity_v * opacities[idx]; + dL_dopacity[idx] = dL_dopacity_v * h_convolution_scaling; + const float d_inside_root = (det_cov / det_cov_plus_h_cov) <= 0.000025f ? 0.f : d_h_convolution_scaling / (2 * h_convolution_scaling); +#else + c_xx += h_var; + c_yy += h_var; +#endif + + float dL_dc_xx = 0; + float dL_dc_xy = 0; + float dL_dc_yy = 0; +#ifdef DGR_FIX_AA + { + // https://www.wolframalpha.com/input?i=d+%28%28x*y+-+z%5E2%29%2F%28%28x%2Bw%29*%28y%2Bw%29+-+z%5E2%29%29+%2Fdx + // https://www.wolframalpha.com/input?i=d+%28%28x*y+-+z%5E2%29%2F%28%28x%2Bw%29*%28y%2Bw%29+-+z%5E2%29%29+%2Fdz + const float x = c_xx; + const float y = c_yy; + const float z = c_xy; + const float w = h_var; + const float denom_f = d_inside_root / sq(w * w + w * (x + y) + x * y - z * z); + const float dL_dx = w * (w * y + y * y + z * z) * denom_f; + const float dL_dy = w * (w * x + x * x + z * z) * denom_f; + const float dL_dz = -2.f * w * z * (w + x + y) * denom_f; + dL_dc_xx = dL_dx; + dL_dc_yy = dL_dy; + dL_dc_xy = dL_dz; + } +#endif + + float denom = c_xx * c_yy - c_xy * c_xy; - float denom = a * c - b * b; - float dL_da = 0, dL_db = 0, dL_dc = 0; float denom2inv = 1.0f / ((denom * denom) + 0.0000001f); if (denom2inv != 0) @@ -207,24 +250,25 @@ __global__ void computeCov2DCUDA(int P, // Gradients of loss w.r.t. entries of 2D covariance matrix, // given gradients of loss w.r.t. conic matrix (inverse covariance matrix). // e.g., dL / da = dL / d_conic_a * d_conic_a / d_a - dL_da = denom2inv * (-c * c * dL_dconic.x + 2 * b * c * dL_dconic.y + (denom - a * c) * dL_dconic.z); - dL_dc = denom2inv * (-a * a * dL_dconic.z + 2 * a * b * dL_dconic.y + (denom - a * c) * dL_dconic.x); - dL_db = denom2inv * 2 * (b * c * dL_dconic.x - (denom + 2 * b * b) * dL_dconic.y + a * b * dL_dconic.z); - - // Gradients of loss L w.r.t. each 3D covariance matrix (Vrk) entry, + + dL_dc_xx += denom2inv * (-c_yy * c_yy * dL_dconic.x + 2 * c_xy * c_yy * dL_dconic.y + (denom - c_xx * c_yy) * dL_dconic.z); + dL_dc_yy += denom2inv * (-c_xx * c_xx * dL_dconic.z + 2 * c_xx * c_xy * dL_dconic.y + (denom - c_xx * c_yy) * dL_dconic.x); + dL_dc_xy += denom2inv * 2 * (c_xy * c_yy * dL_dconic.x - (denom + 2 * c_xy * c_xy) * dL_dconic.y + c_xx * c_xy * dL_dconic.z); + + // Gradients of loss L w.r.t. each 3D covariance matrix (Vrk) entry, // given gradients w.r.t. 2D covariance matrix (diagonal). // cov2D = transpose(T) * transpose(Vrk) * T; - dL_dcov[6 * idx + 0] = (T[0][0] * T[0][0] * dL_da + T[0][0] * T[1][0] * dL_db + T[1][0] * T[1][0] * dL_dc); - dL_dcov[6 * idx + 3] = (T[0][1] * T[0][1] * dL_da + T[0][1] * T[1][1] * dL_db + T[1][1] * T[1][1] * dL_dc); - dL_dcov[6 * idx + 5] = (T[0][2] * T[0][2] * dL_da + T[0][2] * T[1][2] * dL_db + T[1][2] * T[1][2] * dL_dc); - - // Gradients of loss L w.r.t. each 3D covariance matrix (Vrk) entry, + dL_dcov[6 * idx + 0] = (T[0][0] * T[0][0] * dL_dc_xx + T[0][0] * T[1][0] * dL_dc_xy + T[1][0] * T[1][0] * dL_dc_yy); + dL_dcov[6 * idx + 3] = (T[0][1] * T[0][1] * dL_dc_xx + T[0][1] * T[1][1] * dL_dc_xy + T[1][1] * T[1][1] * dL_dc_yy); + dL_dcov[6 * idx + 5] = (T[0][2] * T[0][2] * dL_dc_xx + T[0][2] * T[1][2] * dL_dc_xy + T[1][2] * T[1][2] * dL_dc_yy); + + // Gradients of loss L w.r.t. each 3D covariance matrix (Vrk) entry, // given gradients w.r.t. 2D covariance matrix (off-diagonal). // Off-diagonal elements appear twice --> double the gradient. // cov2D = transpose(T) * transpose(Vrk) * T; - dL_dcov[6 * idx + 1] = 2 * T[0][0] * T[0][1] * dL_da + (T[0][0] * T[1][1] + T[0][1] * T[1][0]) * dL_db + 2 * T[1][0] * T[1][1] * dL_dc; - dL_dcov[6 * idx + 2] = 2 * T[0][0] * T[0][2] * dL_da + (T[0][0] * T[1][2] + T[0][2] * T[1][0]) * dL_db + 2 * T[1][0] * T[1][2] * dL_dc; - dL_dcov[6 * idx + 4] = 2 * T[0][2] * T[0][1] * dL_da + (T[0][1] * T[1][2] + T[0][2] * T[1][1]) * dL_db + 2 * T[1][1] * T[1][2] * dL_dc; + dL_dcov[6 * idx + 1] = 2 * T[0][0] * T[0][1] * dL_dc_xx + (T[0][0] * T[1][1] + T[0][1] * T[1][0]) * dL_dc_xy + 2 * T[1][0] * T[1][1] * dL_dc_yy; + dL_dcov[6 * idx + 2] = 2 * T[0][0] * T[0][2] * dL_dc_xx + (T[0][0] * T[1][2] + T[0][2] * T[1][0]) * dL_dc_xy + 2 * T[1][0] * T[1][2] * dL_dc_yy; + dL_dcov[6 * idx + 4] = 2 * T[0][2] * T[0][1] * dL_dc_xx + (T[0][1] * T[1][2] + T[0][2] * T[1][1]) * dL_dc_xy + 2 * T[1][1] * T[1][2] * dL_dc_yy; } else { @@ -234,18 +278,18 @@ __global__ void computeCov2DCUDA(int P, // Gradients of loss w.r.t. upper 2x3 portion of intermediate matrix T // cov2D = transpose(T) * transpose(Vrk) * T; - float dL_dT00 = 2 * (T[0][0] * Vrk[0][0] + T[0][1] * Vrk[0][1] + T[0][2] * Vrk[0][2]) * dL_da + - (T[1][0] * Vrk[0][0] + T[1][1] * Vrk[0][1] + T[1][2] * Vrk[0][2]) * dL_db; - float dL_dT01 = 2 * (T[0][0] * Vrk[1][0] + T[0][1] * Vrk[1][1] + T[0][2] * Vrk[1][2]) * dL_da + - (T[1][0] * Vrk[1][0] + T[1][1] * Vrk[1][1] + T[1][2] * Vrk[1][2]) * dL_db; - float dL_dT02 = 2 * (T[0][0] * Vrk[2][0] + T[0][1] * Vrk[2][1] + T[0][2] * Vrk[2][2]) * dL_da + - (T[1][0] * Vrk[2][0] + T[1][1] * Vrk[2][1] + T[1][2] * Vrk[2][2]) * dL_db; - float dL_dT10 = 2 * (T[1][0] * Vrk[0][0] + T[1][1] * Vrk[0][1] + T[1][2] * Vrk[0][2]) * dL_dc + - (T[0][0] * Vrk[0][0] + T[0][1] * Vrk[0][1] + T[0][2] * Vrk[0][2]) * dL_db; - float dL_dT11 = 2 * (T[1][0] * Vrk[1][0] + T[1][1] * Vrk[1][1] + T[1][2] * Vrk[1][2]) * dL_dc + - (T[0][0] * Vrk[1][0] + T[0][1] * Vrk[1][1] + T[0][2] * Vrk[1][2]) * dL_db; - float dL_dT12 = 2 * (T[1][0] * Vrk[2][0] + T[1][1] * Vrk[2][1] + T[1][2] * Vrk[2][2]) * dL_dc + - (T[0][0] * Vrk[2][0] + T[0][1] * Vrk[2][1] + T[0][2] * Vrk[2][2]) * dL_db; + float dL_dT00 = 2 * (T[0][0] * Vrk[0][0] + T[0][1] * Vrk[0][1] + T[0][2] * Vrk[0][2]) * dL_dc_xx + + (T[1][0] * Vrk[0][0] + T[1][1] * Vrk[0][1] + T[1][2] * Vrk[0][2]) * dL_dc_xy; + float dL_dT01 = 2 * (T[0][0] * Vrk[1][0] + T[0][1] * Vrk[1][1] + T[0][2] * Vrk[1][2]) * dL_dc_xx + + (T[1][0] * Vrk[1][0] + T[1][1] * Vrk[1][1] + T[1][2] * Vrk[1][2]) * dL_dc_xy; + float dL_dT02 = 2 * (T[0][0] * Vrk[2][0] + T[0][1] * Vrk[2][1] + T[0][2] * Vrk[2][2]) * dL_dc_xx + + (T[1][0] * Vrk[2][0] + T[1][1] * Vrk[2][1] + T[1][2] * Vrk[2][2]) * dL_dc_xy; + float dL_dT10 = 2 * (T[1][0] * Vrk[0][0] + T[1][1] * Vrk[0][1] + T[1][2] * Vrk[0][2]) * dL_dc_yy + + (T[0][0] * Vrk[0][0] + T[0][1] * Vrk[0][1] + T[0][2] * Vrk[0][2]) * dL_dc_xy; + float dL_dT11 = 2 * (T[1][0] * Vrk[1][0] + T[1][1] * Vrk[1][1] + T[1][2] * Vrk[1][2]) * dL_dc_yy + + (T[0][0] * Vrk[1][0] + T[0][1] * Vrk[1][1] + T[0][2] * Vrk[1][2]) * dL_dc_xy; + float dL_dT12 = 2 * (T[1][0] * Vrk[2][0] + T[1][1] * Vrk[2][1] + T[1][2] * Vrk[2][2]) * dL_dc_yy + + (T[0][0] * Vrk[2][0] + T[0][1] * Vrk[2][1] + T[0][2] * Vrk[2][2]) * dL_dc_xy; // Gradients of loss w.r.t. upper 3x2 non-zero entries of Jacobian matrix // T = W * J @@ -262,6 +306,10 @@ __global__ void computeCov2DCUDA(int P, float dL_dtx = x_grad_mul * -h_x * tz2 * dL_dJ02; float dL_dty = y_grad_mul * -h_y * tz2 * dL_dJ12; float dL_dtz = -h_x * tz2 * dL_dJ00 - h_y * tz2 * dL_dJ11 + (2 * h_x * t.x) * tz3 * dL_dJ02 + (2 * h_y * t.y) * tz3 * dL_dJ12; + // Account for inverse depth gradients + if (dL_dinvdepth) + dL_dtz -= dL_dinvdepth[idx] / (t.z * t.z); + // Account for transformation of mean to t // t = transformPoint4x3(mean, view_matrix); @@ -361,7 +409,8 @@ __global__ void preprocessCUDA( float* dL_dcov3D, float* dL_dsh, glm::vec3* dL_dscale, - glm::vec4* dL_drot) + glm::vec4* dL_drot, + float* dL_dopacity) { auto idx = cg::this_grid().thread_rank(); if (idx >= P || !(radii[idx] > 0)) @@ -406,13 +455,17 @@ renderCUDA( const float2* __restrict__ points_xy_image, const float4* __restrict__ conic_opacity, const float* __restrict__ colors, + const float* __restrict__ depths, const float* __restrict__ final_Ts, const uint32_t* __restrict__ n_contrib, const float* __restrict__ dL_dpixels, + const float* __restrict__ dL_invdepths, float3* __restrict__ dL_dmean2D, float4* __restrict__ dL_dconic2D, float* __restrict__ dL_dopacity, - float* __restrict__ dL_dcolors) + float* __restrict__ dL_dcolors, + float* __restrict__ dL_dinvdepths +) { // We rasterize again. Compute necessary block info. auto block = cg::this_thread_block(); @@ -435,6 +488,8 @@ renderCUDA( __shared__ float2 collected_xy[BLOCK_SIZE]; __shared__ float4 collected_conic_opacity[BLOCK_SIZE]; __shared__ float collected_colors[C * BLOCK_SIZE]; + __shared__ float collected_depths[BLOCK_SIZE]; + // In the forward, we stored the final value for T, the // product of all (1 - alpha) factors. @@ -448,12 +503,20 @@ renderCUDA( float accum_rec[C] = { 0 }; float dL_dpixel[C]; + float dL_invdepth; + float accum_invdepth_rec = 0; if (inside) + { for (int i = 0; i < C; i++) dL_dpixel[i] = dL_dpixels[i * H * W + pix_id]; + if(dL_invdepths) + dL_invdepth = dL_invdepths[pix_id]; + } float last_alpha = 0; float last_color[C] = { 0 }; + float last_invdepth = 0; + // Gradient of pixel coordinate w.r.t. normalized // screen-space viewport corrdinates (-1 to 1) @@ -475,6 +538,9 @@ renderCUDA( collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id]; for (int i = 0; i < C; i++) collected_colors[i * BLOCK_SIZE + block.thread_rank()] = colors[coll_id * C + i]; + + if(dL_invdepths) + collected_depths[block.thread_rank()] = depths[coll_id]; } block.sync(); @@ -522,6 +588,17 @@ renderCUDA( // many that were affected by this Gaussian. atomicAdd(&(dL_dcolors[global_id * C + ch]), dchannel_dcolor * dL_dchannel); } + // Propagate gradients from inverse depth to alphaas and + // per Gaussian inverse depths + if (dL_dinvdepths) + { + const float invd = 1.f / collected_depths[j]; + accum_invdepth_rec = last_alpha * last_invdepth + (1.f - last_alpha) * accum_invdepth_rec; + last_invdepth = invd; + dL_dalpha += (invd - accum_invdepth_rec) * dL_invdepth; + atomicAdd(&(dL_dinvdepths[global_id]), dchannel_dcolor * dL_invdepth); + } + dL_dalpha *= T; // Update last alpha (to be used in the next iteration) last_alpha = alpha; @@ -562,6 +639,7 @@ void BACKWARD::preprocess( const int* radii, const float* shs, const bool* clamped, + const float* opacities, const glm::vec3* scales, const glm::vec4* rotations, const float scale_modifier, @@ -573,6 +651,8 @@ void BACKWARD::preprocess( const glm::vec3* campos, const float3* dL_dmean2D, const float* dL_dconic, + const float* dL_dinvdepth, + float* dL_dopacity, glm::vec3* dL_dmean3D, float* dL_dcolor, float* dL_dcov3D, @@ -594,7 +674,10 @@ void BACKWARD::preprocess( tan_fovx, tan_fovy, viewmatrix, + opacities, dL_dconic, + dL_dopacity, + dL_dinvdepth, (float3*)dL_dmean3D, dL_dcov3D); @@ -618,7 +701,8 @@ void BACKWARD::preprocess( dL_dcov3D, dL_dsh, dL_dscale, - dL_drot); + dL_drot, + dL_dopacity); } void BACKWARD::render( @@ -630,13 +714,16 @@ void BACKWARD::render( const float2* means2D, const float4* conic_opacity, const float* colors, + const float* depths, const float* final_Ts, const uint32_t* n_contrib, const float* dL_dpixels, + const float* dL_invdepths, float3* dL_dmean2D, float4* dL_dconic2D, float* dL_dopacity, - float* dL_dcolors) + float* dL_dcolors, + float* dL_dinvdepths) { renderCUDA << > >( ranges, @@ -646,12 +733,15 @@ void BACKWARD::render( means2D, conic_opacity, colors, + depths, final_Ts, n_contrib, dL_dpixels, + dL_invdepths, dL_dmean2D, dL_dconic2D, dL_dopacity, - dL_dcolors + dL_dcolors, + dL_dinvdepths ); -} \ No newline at end of file +} diff --git a/cuda_rasterizer/backward.h b/cuda_rasterizer/backward.h index 93dd2e4b..a9c9772f 100644 --- a/cuda_rasterizer/backward.h +++ b/cuda_rasterizer/backward.h @@ -29,13 +29,16 @@ namespace BACKWARD const float2* means2D, const float4* conic_opacity, const float* colors, + const float* depths, const float* final_Ts, const uint32_t* n_contrib, const float* dL_dpixels, + const float* dL_invdepths, float3* dL_dmean2D, float4* dL_dconic2D, float* dL_dopacity, - float* dL_dcolors); + float* dL_dcolors, + float* dL_dinvdepths); void preprocess( int P, int D, int M, @@ -43,6 +46,7 @@ namespace BACKWARD const int* radii, const float* shs, const bool* clamped, + const float* opacities, const glm::vec3* scales, const glm::vec4* rotations, const float scale_modifier, @@ -54,6 +58,8 @@ namespace BACKWARD const glm::vec3* campos, const float3* dL_dmean2D, const float* dL_dconics, + const float* dL_dinvdepth, + float* dL_dopacity, glm::vec3* dL_dmeans, float* dL_dcolor, float* dL_dcov3D, @@ -62,4 +68,4 @@ namespace BACKWARD glm::vec4* dL_drot); } -#endif \ No newline at end of file +#endif diff --git a/cuda_rasterizer/forward.cu b/cuda_rasterizer/forward.cu index c419a328..847a88b3 100644 --- a/cuda_rasterizer/forward.cu +++ b/cuda_rasterizer/forward.cu @@ -105,10 +105,6 @@ __device__ float3 computeCov2D(const float3& mean, float focal_x, float focal_y, glm::mat3 cov = glm::transpose(T) * glm::transpose(Vrk) * T; - // Apply low-pass filter: every Gaussian should be at least - // one pixel wide/high. Discard 3rd row and column. - cov[0][0] += 0.3f; - cov[1][1] += 0.3f; return { float(cov[0][0]), float(cov[0][1]), float(cov[1][1]) }; } @@ -215,8 +211,19 @@ __global__ void preprocessCUDA(int P, int D, int M, // Compute 2D screen-space covariance matrix float3 cov = computeCov2D(p_orig, focal_x, focal_y, tan_fovx, tan_fovy, cov3D, viewmatrix); + constexpr float h_var = 0.3f; + const float det_cov = cov.x * cov.z - cov.y * cov.y; + cov.x += h_var; + cov.z += h_var; + const float det_cov_plus_h_cov = cov.x * cov.z - cov.y * cov.y; + +#ifdef DGR_FIX_AA + const float h_convolution_scaling = sqrt(max(0.000025f, det_cov / det_cov_plus_h_cov)); // max for numerical stability +#endif + // Invert covariance (EWA algorithm) - float det = (cov.x * cov.z - cov.y * cov.y); + const float det = det_cov_plus_h_cov; + if (det == 0.0f) return; float det_inv = 1.f / det; @@ -251,7 +258,14 @@ __global__ void preprocessCUDA(int P, int D, int M, radii[idx] = my_radius; points_xy_image[idx] = point_image; // Inverse 2D covariance and opacity neatly pack into one float4 - conic_opacity[idx] = { conic.x, conic.y, conic.z, opacities[idx] }; + float opacity = opacities[idx]; + +#ifdef DGR_FIX_AA + conic_opacity[idx] = { conic.x, conic.y, conic.z, opacity * h_convolution_scaling }; +#else + conic_opacity[idx] = { conic.x, conic.y, conic.z, opacity }; +#endif + tiles_touched[idx] = (rect_max.y - rect_min.y) * (rect_max.x - rect_min.x); } @@ -270,7 +284,9 @@ renderCUDA( float* __restrict__ final_T, uint32_t* __restrict__ n_contrib, const float* __restrict__ bg_color, - float* __restrict__ out_color) + float* __restrict__ out_color, + const float* __restrict__ depths, + float* __restrict__ invdepth) { // Identify current tile and associated min/max pixel range. auto block = cg::this_thread_block(); @@ -302,6 +318,8 @@ renderCUDA( uint32_t last_contributor = 0; float C[CHANNELS] = { 0 }; + float expected_invdepth = 0.0f; + // Iterate over batches until all done or range is complete for (int i = 0; i < rounds; i++, toDo -= BLOCK_SIZE) { @@ -354,6 +372,9 @@ renderCUDA( for (int ch = 0; ch < CHANNELS; ch++) C[ch] += features[collected_id[j] * CHANNELS + ch] * alpha * T; + if(invdepth) + expected_invdepth += (1 / depths[collected_id[j]]) * alpha * T; + T = test_T; // Keep track of last range entry to update this @@ -370,6 +391,9 @@ renderCUDA( n_contrib[pix_id] = last_contributor; for (int ch = 0; ch < CHANNELS; ch++) out_color[ch * H * W + pix_id] = C[ch] + T * bg_color[ch]; + + if (invdepth) + invdepth[pix_id] = expected_invdepth;// 1. / (expected_depth + T * 1e3); } } @@ -384,7 +408,9 @@ void FORWARD::render( float* final_T, uint32_t* n_contrib, const float* bg_color, - float* out_color) + float* out_color, + float* depths, + float* depth) { renderCUDA << > > ( ranges, @@ -396,7 +422,9 @@ void FORWARD::render( final_T, n_contrib, bg_color, - out_color); + out_color, + depths, + depth); } void FORWARD::preprocess(int P, int D, int M, @@ -452,4 +480,4 @@ void FORWARD::preprocess(int P, int D, int M, tiles_touched, prefiltered ); -} \ No newline at end of file +} diff --git a/cuda_rasterizer/forward.h b/cuda_rasterizer/forward.h index 3c11cb91..7eebd940 100644 --- a/cuda_rasterizer/forward.h +++ b/cuda_rasterizer/forward.h @@ -59,8 +59,10 @@ namespace FORWARD float* final_T, uint32_t* n_contrib, const float* bg_color, - float* out_color); + float* out_color, + float* depths, + float* depth); } -#endif \ No newline at end of file +#endif diff --git a/cuda_rasterizer/rasterizer.h b/cuda_rasterizer/rasterizer.h index 81544ef6..6fab9d82 100644 --- a/cuda_rasterizer/rasterizer.h +++ b/cuda_rasterizer/rasterizer.h @@ -49,6 +49,7 @@ namespace CudaRasterizer const float tan_fovx, float tan_fovy, const bool prefiltered, float* out_color, + float* depth, int* radii = nullptr, bool debug = false); @@ -59,6 +60,7 @@ namespace CudaRasterizer const float* means3D, const float* shs, const float* colors_precomp, + const float* opacities, const float* scales, const float scale_modifier, const float* rotations, @@ -72,10 +74,12 @@ namespace CudaRasterizer char* binning_buffer, char* image_buffer, const float* dL_dpix, + const float* dL_invdepths, float* dL_dmean2D, float* dL_dconic, float* dL_dopacity, float* dL_dcolor, + float* dL_dinvdepth, float* dL_dmean3D, float* dL_dcov3D, float* dL_dsh, @@ -85,4 +89,4 @@ namespace CudaRasterizer }; }; -#endif \ No newline at end of file +#endif diff --git a/cuda_rasterizer/rasterizer_impl.cu b/cuda_rasterizer/rasterizer_impl.cu index f8782ac4..4969ada1 100644 --- a/cuda_rasterizer/rasterizer_impl.cu +++ b/cuda_rasterizer/rasterizer_impl.cu @@ -216,6 +216,7 @@ int CudaRasterizer::Rasterizer::forward( const float tan_fovx, float tan_fovy, const bool prefiltered, float* out_color, + float* depth, int* radii, bool debug) { @@ -330,7 +331,9 @@ int CudaRasterizer::Rasterizer::forward( imgState.accum_alpha, imgState.n_contrib, background, - out_color), debug) + out_color, + geomState.depths, + depth), debug) return num_rendered; } @@ -344,6 +347,7 @@ void CudaRasterizer::Rasterizer::backward( const float* means3D, const float* shs, const float* colors_precomp, + const float* opacities, const float* scales, const float scale_modifier, const float* rotations, @@ -357,10 +361,12 @@ void CudaRasterizer::Rasterizer::backward( char* binning_buffer, char* img_buffer, const float* dL_dpix, + const float* dL_invdepths, float* dL_dmean2D, float* dL_dconic, float* dL_dopacity, float* dL_dcolor, + float* dL_dinvdepth, float* dL_dmean3D, float* dL_dcov3D, float* dL_dsh, @@ -397,13 +403,16 @@ void CudaRasterizer::Rasterizer::backward( geomState.means2D, geomState.conic_opacity, color_ptr, + geomState.depths, imgState.accum_alpha, imgState.n_contrib, dL_dpix, + dL_invdepths, (float3*)dL_dmean2D, (float4*)dL_dconic, dL_dopacity, - dL_dcolor), debug) + dL_dcolor, + dL_dinvdepth), debug); // Take care of the rest of preprocessing. Was the precomputed covariance // given to us or a scales/rot pair? If precomputed, pass that. If not, @@ -414,6 +423,7 @@ void CudaRasterizer::Rasterizer::backward( radii, shs, geomState.clamped, + opacities, (glm::vec3*)scales, (glm::vec4*)rotations, scale_modifier, @@ -425,10 +435,12 @@ void CudaRasterizer::Rasterizer::backward( (glm::vec3*)campos, (float3*)dL_dmean2D, dL_dconic, + dL_dinvdepth, + dL_dopacity, (glm::vec3*)dL_dmean3D, dL_dcolor, dL_dcov3D, dL_dsh, (glm::vec3*)dL_dscale, - (glm::vec4*)dL_drot), debug) -} \ No newline at end of file + (glm::vec4*)dL_drot), debug); +} diff --git a/diff_gaussian_rasterization/__init__.py b/diff_gaussian_rasterization/__init__.py index bbef37d1..0cd4a093 100644 --- a/diff_gaussian_rasterization/__init__.py +++ b/diff_gaussian_rasterization/__init__.py @@ -80,36 +80,28 @@ def forward( ) # Invoke C++/CUDA rasterizer - if raster_settings.debug: - cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted - try: - num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args) - except Exception as ex: - torch.save(cpu_args, "snapshot_fw.dump") - print("\nAn error occured in forward. Please forward snapshot_fw.dump for debugging.") - raise ex - else: - num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args) + num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer, invdepths = _C.rasterize_gaussians(*args) # Keep relevant tensors for backward ctx.raster_settings = raster_settings ctx.num_rendered = num_rendered - ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer) - return color, radii + ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, opacities, geomBuffer, binningBuffer, imgBuffer) + return color, radii, invdepths @staticmethod - def backward(ctx, grad_out_color, _): + def backward(ctx, grad_out_color, _, grad_out_depth): # Restore necessary values from context num_rendered = ctx.num_rendered raster_settings = ctx.raster_settings - colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer = ctx.saved_tensors + colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, opacities, geomBuffer, binningBuffer, imgBuffer = ctx.saved_tensors # Restructure args as C++ method expects them args = (raster_settings.bg, means3D, radii, colors_precomp, + opacities, scales, rotations, raster_settings.scale_modifier, @@ -118,7 +110,8 @@ def backward(ctx, grad_out_color, _): raster_settings.projmatrix, raster_settings.tanfovx, raster_settings.tanfovy, - grad_out_color, + grad_out_color, + grad_out_depth, sh, raster_settings.sh_degree, raster_settings.campos, @@ -129,16 +122,7 @@ def backward(ctx, grad_out_color, _): raster_settings.debug) # Compute gradients for relevant tensors by invoking backward method - if raster_settings.debug: - cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted - try: - grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args) - except Exception as ex: - torch.save(cpu_args, "snapshot_bw.dump") - print("\nAn error occured in backward. Writing snapshot_bw.dump for debugging.\n") - raise ex - else: - grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args) + grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args) grads = ( grad_means3D, diff --git a/rasterize_points.cu b/rasterize_points.cu index ddc5cf8b..82abd6fc 100644 --- a/rasterize_points.cu +++ b/rasterize_points.cu @@ -32,7 +32,7 @@ std::function resizeFunctional(torch::Tensor& t) { return lambda; } -std::tuple +std::tuple RasterizeGaussiansCUDA( const torch::Tensor& background, const torch::Tensor& means3D, @@ -66,6 +66,12 @@ RasterizeGaussiansCUDA( auto float_opts = means3D.options().dtype(torch::kFloat32); torch::Tensor out_color = torch::full({NUM_CHANNELS, H, W}, 0.0, float_opts); + torch::Tensor out_invdepth = torch::full({0, H, W}, 0.0, float_opts); + float* out_invdepthptr = nullptr; + + out_invdepth = torch::full({1, H, W}, 0.0, float_opts).contiguous(); + out_invdepthptr = out_invdepth.data(); + torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32)); torch::Device device(torch::kCUDA); @@ -108,10 +114,11 @@ RasterizeGaussiansCUDA( tan_fovy, prefiltered, out_color.contiguous().data(), + out_invdepthptr, radii.contiguous().data(), debug); } - return std::make_tuple(rendered, out_color, radii, geomBuffer, binningBuffer, imgBuffer); + return std::make_tuple(rendered, out_color, radii, geomBuffer, binningBuffer, imgBuffer, out_invdepth); } std::tuple @@ -120,6 +127,7 @@ std::tuple(); + dL_dout_invdepthptr = dL_dout_invdepth.data(); + } + if(P != 0) { CudaRasterizer::Rasterizer::backward(P, degree, M, R, @@ -166,6 +186,7 @@ std::tuple(), sh.contiguous().data(), colors.contiguous().data(), + opacities.contiguous().data(), scales.data_ptr(), scale_modifier, rotations.data_ptr(), @@ -180,10 +201,12 @@ std::tuple(binningBuffer.contiguous().data_ptr()), reinterpret_cast(imageBuffer.contiguous().data_ptr()), dL_dout_color.contiguous().data(), + dL_dout_invdepthptr, dL_dmeans2D.contiguous().data(), dL_dconic.contiguous().data(), dL_dopacity.contiguous().data(), dL_dcolors.contiguous().data(), + dL_dinvdepthsptr, dL_dmeans3D.contiguous().data(), dL_dcov3D.contiguous().data(), dL_dsh.contiguous().data(), @@ -214,4 +237,4 @@ torch::Tensor markVisible( } return present; -} \ No newline at end of file +} diff --git a/rasterize_points.h b/rasterize_points.h index 9023d994..75574363 100644 --- a/rasterize_points.h +++ b/rasterize_points.h @@ -15,7 +15,7 @@ #include #include -std::tuple +std::tuple RasterizeGaussiansCUDA( const torch::Tensor& background, const torch::Tensor& means3D, @@ -44,6 +44,7 @@ std::tuple Date: Fri, 6 Sep 2024 16:31:26 +0200 Subject: [PATCH 2/5] toggle antialiasing --- cuda_rasterizer/auxiliary.h | 1 - cuda_rasterizer/backward.cu | 70 ++++++++++++++----------- cuda_rasterizer/backward.h | 3 +- cuda_rasterizer/forward.cu | 21 ++++---- cuda_rasterizer/forward.h | 3 +- cuda_rasterizer/rasterizer.h | 2 + cuda_rasterizer/rasterizer_impl.cu | 8 ++- diff_gaussian_rasterization/__init__.py | 3 ++ rasterize_points.cu | 6 ++- rasterize_points.h | 2 + 10 files changed, 71 insertions(+), 48 deletions(-) diff --git a/cuda_rasterizer/auxiliary.h b/cuda_rasterizer/auxiliary.h index 7f6b7548..aaf21915 100644 --- a/cuda_rasterizer/auxiliary.h +++ b/cuda_rasterizer/auxiliary.h @@ -17,7 +17,6 @@ #define BLOCK_SIZE (BLOCK_X * BLOCK_Y) #define NUM_WARPS (BLOCK_SIZE/32) -#define DGR_FIX_AA // Spherical harmonics coefficients __device__ const float SH_C0 = 0.28209479177387814f; __device__ const float SH_C1 = 0.4886025119029199f; diff --git a/cuda_rasterizer/backward.cu b/cuda_rasterizer/backward.cu index 7a609ac8..006b1ef7 100644 --- a/cuda_rasterizer/backward.cu +++ b/cuda_rasterizer/backward.cu @@ -156,7 +156,8 @@ __global__ void computeCov2DCUDA(int P, float* dL_dopacity, const float* dL_dinvdepth, float3* dL_dmeans, - float* dL_dcov) + float* dL_dcov, + bool antialiasing) { auto idx = cg::this_grid().thread_rank(); if (idx >= P || !(radii[idx] > 0)) @@ -205,41 +206,44 @@ __global__ void computeCov2DCUDA(int P, float c_yy = cov2D[1][1]; constexpr float h_var = 0.3f; -#ifdef DGR_FIX_AA - const float det_cov = c_xx * c_yy - c_xy * c_xy; - c_xx += h_var; - c_yy += h_var; - const float det_cov_plus_h_cov = c_xx * c_yy - c_xy * c_xy; - const float h_convolution_scaling = sqrt(max(0.000025f, det_cov / det_cov_plus_h_cov)); // max for numerical stability - const float dL_dopacity_v = dL_dopacity[idx]; - const float d_h_convolution_scaling = dL_dopacity_v * opacities[idx]; - dL_dopacity[idx] = dL_dopacity_v * h_convolution_scaling; - const float d_inside_root = (det_cov / det_cov_plus_h_cov) <= 0.000025f ? 0.f : d_h_convolution_scaling / (2 * h_convolution_scaling); -#else - c_xx += h_var; - c_yy += h_var; -#endif + float d_inside_root = 0.f; + if(antialiasing) + { + const float det_cov = c_xx * c_yy - c_xy * c_xy; + c_xx += h_var; + c_yy += h_var; + const float det_cov_plus_h_cov = c_xx * c_yy - c_xy * c_xy; + const float h_convolution_scaling = sqrt(max(0.000025f, det_cov / det_cov_plus_h_cov)); // max for numerical stability + const float dL_dopacity_v = dL_dopacity[idx]; + const float d_h_convolution_scaling = dL_dopacity_v * opacities[idx]; + dL_dopacity[idx] = dL_dopacity_v * h_convolution_scaling; + d_inside_root = (det_cov / det_cov_plus_h_cov) <= 0.000025f ? 0.f : d_h_convolution_scaling / (2 * h_convolution_scaling); + } + else + { + c_xx += h_var; + c_yy += h_var; + } float dL_dc_xx = 0; float dL_dc_xy = 0; float dL_dc_yy = 0; -#ifdef DGR_FIX_AA + if(antialiasing) { - // https://www.wolframalpha.com/input?i=d+%28%28x*y+-+z%5E2%29%2F%28%28x%2Bw%29*%28y%2Bw%29+-+z%5E2%29%29+%2Fdx - // https://www.wolframalpha.com/input?i=d+%28%28x*y+-+z%5E2%29%2F%28%28x%2Bw%29*%28y%2Bw%29+-+z%5E2%29%29+%2Fdz - const float x = c_xx; - const float y = c_yy; - const float z = c_xy; - const float w = h_var; - const float denom_f = d_inside_root / sq(w * w + w * (x + y) + x * y - z * z); - const float dL_dx = w * (w * y + y * y + z * z) * denom_f; - const float dL_dy = w * (w * x + x * x + z * z) * denom_f; - const float dL_dz = -2.f * w * z * (w + x + y) * denom_f; - dL_dc_xx = dL_dx; - dL_dc_yy = dL_dy; - dL_dc_xy = dL_dz; + // https://www.wolframalpha.com/input?i=d+%28%28x*y+-+z%5E2%29%2F%28%28x%2Bw%29*%28y%2Bw%29+-+z%5E2%29%29+%2Fdx + // https://www.wolframalpha.com/input?i=d+%28%28x*y+-+z%5E2%29%2F%28%28x%2Bw%29*%28y%2Bw%29+-+z%5E2%29%29+%2Fdz + const float x = c_xx; + const float y = c_yy; + const float z = c_xy; + const float w = h_var; + const float denom_f = d_inside_root / sq(w * w + w * (x + y) + x * y - z * z); + const float dL_dx = w * (w * y + y * y + z * z) * denom_f; + const float dL_dy = w * (w * x + x * x + z * z) * denom_f; + const float dL_dz = -2.f * w * z * (w + x + y) * denom_f; + dL_dc_xx = dL_dx; + dL_dc_yy = dL_dy; + dL_dc_xy = dL_dz; } -#endif float denom = c_xx * c_yy - c_xy * c_xy; @@ -658,7 +662,8 @@ void BACKWARD::preprocess( float* dL_dcov3D, float* dL_dsh, glm::vec3* dL_dscale, - glm::vec4* dL_drot) + glm::vec4* dL_drot, + bool antialiasing) { // Propagate gradients for the path of 2D conic matrix computation. // Somewhat long, thus it is its own kernel rather than being part of @@ -679,7 +684,8 @@ void BACKWARD::preprocess( dL_dopacity, dL_dinvdepth, (float3*)dL_dmean3D, - dL_dcov3D); + dL_dcov3D, + antialiasing); // Propagate gradients for remaining steps: finish 3D mean gradients, // propagate color gradients to SH (if desireD), propagate 3D covariance diff --git a/cuda_rasterizer/backward.h b/cuda_rasterizer/backward.h index a9c9772f..4d02560f 100644 --- a/cuda_rasterizer/backward.h +++ b/cuda_rasterizer/backward.h @@ -65,7 +65,8 @@ namespace BACKWARD float* dL_dcov3D, float* dL_dsh, glm::vec3* dL_dscale, - glm::vec4* dL_drot); + glm::vec4* dL_drot, + bool antialiasing); } #endif diff --git a/cuda_rasterizer/forward.cu b/cuda_rasterizer/forward.cu index 847a88b3..c5e01ddc 100644 --- a/cuda_rasterizer/forward.cu +++ b/cuda_rasterizer/forward.cu @@ -173,7 +173,8 @@ __global__ void preprocessCUDA(int P, int D, int M, float4* conic_opacity, const dim3 grid, uint32_t* tiles_touched, - bool prefiltered) + bool prefiltered, + bool antialiasing) { auto idx = cg::this_grid().thread_rank(); if (idx >= P) @@ -216,10 +217,10 @@ __global__ void preprocessCUDA(int P, int D, int M, cov.x += h_var; cov.z += h_var; const float det_cov_plus_h_cov = cov.x * cov.z - cov.y * cov.y; + float h_convolution_scaling = 1.0f; -#ifdef DGR_FIX_AA - const float h_convolution_scaling = sqrt(max(0.000025f, det_cov / det_cov_plus_h_cov)); // max for numerical stability -#endif + if(antialiasing) + h_convolution_scaling = sqrt(max(0.000025f, det_cov / det_cov_plus_h_cov)); // max for numerical stability // Invert covariance (EWA algorithm) const float det = det_cov_plus_h_cov; @@ -260,11 +261,9 @@ __global__ void preprocessCUDA(int P, int D, int M, // Inverse 2D covariance and opacity neatly pack into one float4 float opacity = opacities[idx]; -#ifdef DGR_FIX_AA + conic_opacity[idx] = { conic.x, conic.y, conic.z, opacity * h_convolution_scaling }; -#else - conic_opacity[idx] = { conic.x, conic.y, conic.z, opacity }; -#endif + tiles_touched[idx] = (rect_max.y - rect_min.y) * (rect_max.x - rect_min.x); } @@ -451,7 +450,8 @@ void FORWARD::preprocess(int P, int D, int M, float4* conic_opacity, const dim3 grid, uint32_t* tiles_touched, - bool prefiltered) + bool prefiltered, + bool antialiasing) { preprocessCUDA << <(P + 255) / 256, 256 >> > ( P, D, M, @@ -478,6 +478,7 @@ void FORWARD::preprocess(int P, int D, int M, conic_opacity, grid, tiles_touched, - prefiltered + prefiltered, + antialiasing ); } diff --git a/cuda_rasterizer/forward.h b/cuda_rasterizer/forward.h index 7eebd940..5b18005f 100644 --- a/cuda_rasterizer/forward.h +++ b/cuda_rasterizer/forward.h @@ -45,7 +45,8 @@ namespace FORWARD float4* conic_opacity, const dim3 grid, uint32_t* tiles_touched, - bool prefiltered); + bool prefiltered, + bool antialiasing); // Main rasterization method. void render( diff --git a/cuda_rasterizer/rasterizer.h b/cuda_rasterizer/rasterizer.h index 6fab9d82..64afdee9 100644 --- a/cuda_rasterizer/rasterizer.h +++ b/cuda_rasterizer/rasterizer.h @@ -50,6 +50,7 @@ namespace CudaRasterizer const bool prefiltered, float* out_color, float* depth, + bool antialiasing, int* radii = nullptr, bool debug = false); @@ -85,6 +86,7 @@ namespace CudaRasterizer float* dL_dsh, float* dL_dscale, float* dL_drot, + bool antialiasing, bool debug); }; }; diff --git a/cuda_rasterizer/rasterizer_impl.cu b/cuda_rasterizer/rasterizer_impl.cu index 4969ada1..41c4ed3a 100644 --- a/cuda_rasterizer/rasterizer_impl.cu +++ b/cuda_rasterizer/rasterizer_impl.cu @@ -217,6 +217,7 @@ int CudaRasterizer::Rasterizer::forward( const bool prefiltered, float* out_color, float* depth, + bool antialiasing, int* radii, bool debug) { @@ -270,7 +271,8 @@ int CudaRasterizer::Rasterizer::forward( geomState.conic_opacity, tile_grid, geomState.tiles_touched, - prefiltered + prefiltered, + antialiasing ), debug) // Compute prefix sum over full list of touched tile counts by Gaussians @@ -372,6 +374,7 @@ void CudaRasterizer::Rasterizer::backward( float* dL_dsh, float* dL_dscale, float* dL_drot, + bool antialiasing, bool debug) { GeometryState geomState = GeometryState::fromChunk(geom_buffer, P); @@ -442,5 +445,6 @@ void CudaRasterizer::Rasterizer::backward( dL_dcov3D, dL_dsh, (glm::vec3*)dL_dscale, - (glm::vec4*)dL_drot), debug); + (glm::vec4*)dL_drot, + antialiasing), debug); } diff --git a/diff_gaussian_rasterization/__init__.py b/diff_gaussian_rasterization/__init__.py index 0cd4a093..7f228cec 100644 --- a/diff_gaussian_rasterization/__init__.py +++ b/diff_gaussian_rasterization/__init__.py @@ -76,6 +76,7 @@ def forward( raster_settings.sh_degree, raster_settings.campos, raster_settings.prefiltered, + raster_settings.antialiasing, raster_settings.debug ) @@ -119,6 +120,7 @@ def backward(ctx, grad_out_color, _, grad_out_depth): num_rendered, binningBuffer, imgBuffer, + raster_settings.antialiasing, raster_settings.debug) # Compute gradients for relevant tensors by invoking backward method @@ -151,6 +153,7 @@ class GaussianRasterizationSettings(NamedTuple): campos : torch.Tensor prefiltered : bool debug : bool + antialiasing : bool class GaussianRasterizer(nn.Module): def __init__(self, raster_settings): diff --git a/rasterize_points.cu b/rasterize_points.cu index 82abd6fc..e625c19e 100644 --- a/rasterize_points.cu +++ b/rasterize_points.cu @@ -52,6 +52,7 @@ RasterizeGaussiansCUDA( const int degree, const torch::Tensor& campos, const bool prefiltered, + const bool antialiasing, const bool debug) { if (means3D.ndimension() != 2 || means3D.size(1) != 3) { @@ -115,6 +116,7 @@ RasterizeGaussiansCUDA( prefiltered, out_color.contiguous().data(), out_invdepthptr, + antialiasing, radii.contiguous().data(), debug); } @@ -145,7 +147,8 @@ std::tuple(), dL_dscales.contiguous().data(), dL_drotations.contiguous().data(), + antialiasing, debug); } diff --git a/rasterize_points.h b/rasterize_points.h index 75574363..82cbd4f2 100644 --- a/rasterize_points.h +++ b/rasterize_points.h @@ -35,6 +35,7 @@ RasterizeGaussiansCUDA( const int degree, const torch::Tensor& campos, const bool prefiltered, + const bool antialiasing, const bool debug); std::tuple @@ -61,6 +62,7 @@ std::tuple Date: Mon, 2 Jun 2025 11:25:45 +0300 Subject: [PATCH 3/5] Fix for CUDA 12.8 --- cuda_rasterizer/rasterizer_impl.h | 1 + 1 file changed, 1 insertion(+) diff --git a/cuda_rasterizer/rasterizer_impl.h b/cuda_rasterizer/rasterizer_impl.h index bc3f0ece..9cb61543 100644 --- a/cuda_rasterizer/rasterizer_impl.h +++ b/cuda_rasterizer/rasterizer_impl.h @@ -11,6 +11,7 @@ #pragma once +#include #include #include #include "rasterizer.h" From 19acebb9196db95140790a716ec6b2a5b65c1814 Mon Sep 17 00:00:00 2001 From: Oleg Semery Date: Thu, 24 Jul 2025 18:44:30 +0300 Subject: [PATCH 4/5] Update setup-scripts --- .gitignore | 227 +++++++++++++++++++++++- diff_gaussian_rasterization/__init__.py | 2 + pyproject.toml | 40 +++++ setup.py | 67 ++++--- 4 files changed, 309 insertions(+), 27 deletions(-) create mode 100644 pyproject.toml diff --git a/.gitignore b/.gitignore index 1e1c4ca8..e6803673 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,228 @@ +# PyCharm ### +.idea/ + +# Visual Studio ### +Release/ +Debug/ +.vs/ +*.VC.db +*.sdf +*.suo +*.opendb +*.psess +*.vsp +*.vspx +*.sln +*.pyproj +x64 + +# R ### +.Rhistory + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[codz] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python build/ -diff_gaussian_rasterization.egg-info/ +develop-eggs/ dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py.cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +#uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock +#poetry.toml + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. +# https://pdm-project.org/en/latest/usage/project/#working-with-version-control +#pdm.lock +#pdm.toml +.pdm-python +.pdm-build/ + +# pixi +# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. +#pixi.lock +# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one +# in the .venv directory. It is recommended not to include this directory in version control. +.pixi + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.envrc +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# Abstra +# Abstra is an AI-powered process automation framework. +# Ignore directories containing user credentials, local state, and settings. +# Learn more at https://abstra.io/docs +.abstra/ + +# Visual Studio Code +# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore +# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore +# and can be added to the global gitignore or merged into this file. However, if you prefer, +# you could uncomment the following to ignore the entire vscode folder +# .vscode/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc + +# Cursor +# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to +# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data +# refer to https://docs.cursor.com/context/ignore-files +.cursorignore +.cursorindexingignore + +# Marimo +marimo/_static/ +marimo/_lsp/ +__marimo__/ diff --git a/diff_gaussian_rasterization/__init__.py b/diff_gaussian_rasterization/__init__.py index 7f228cec..51520250 100644 --- a/diff_gaussian_rasterization/__init__.py +++ b/diff_gaussian_rasterization/__init__.py @@ -9,6 +9,8 @@ # For inquiries contact george.drettakis@inria.fr # +__version__ = "0.0.1" + from typing import NamedTuple import torch.nn as nn import torch diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..7e5d8407 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,40 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel", "torch>=1.8"] +build-backend = "setuptools.build_meta" + +[project] +name = "diff_gaussian_rasterization" +dynamic = ["version"] +description = "Differentiable Gaussian Rasterization for Python with CUDA support" +readme = "README.md" +requires-python = ">=3.10" +license-files = ["LICENSE.md"] +authors = [ + { name = "George Drettakis", email = "george.drettakis@inria.fr" }, + { name = "Oleg Sémery", email = "osemery@gmail.com" }, +] +urls = { Homepage = "https://github.com/osmr/diff-gaussian-rasterization" } +keywords = ["gaussian splatting", "3dgs", "3d reconstruction", "image processing", "CUDA"] +classifiers = [ + 'Development Status :: 3 - Alpha', + 'Intended Audience :: Science/Research', + 'Operating System :: OS Independent', + 'Programming Language :: Python :: 3', + 'Topic :: Scientific/Engineering :: Image Processing', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', +] +dependencies = [ + "torch>=1.8", +] + +[tool.wheel] +universal = true + +[tool.setuptools.dynamic] +version = {attr = "diff_gaussian_rasterization.__version__"} + +[tool.setuptools] +include-package-data = true + +[tool.setuptools.packages.find] +exclude = ["others", "*.others", "others.*", "*.others.*"] diff --git a/setup.py b/setup.py index bb7220d2..99781735 100644 --- a/setup.py +++ b/setup.py @@ -1,34 +1,49 @@ -# -# Copyright (C) 2023, Inria -# GRAPHDECO research group, https://team.inria.fr/graphdeco -# All rights reserved. -# -# This software is free for non-commercial, research and evaluation use -# under the terms of the LICENSE.md file. -# -# For inquiries contact george.drettakis@inria.fr -# - -from setuptools import setup -from torch.utils.cpp_extension import CUDAExtension, BuildExtension import os -os.path.dirname(os.path.abspath(__file__)) +from setuptools import setup +from torch.cuda import (is_available as cuda_is_available, + current_device as cuda_current_device, + get_device_capability as cuda_get_device_capability) +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +nvcc_args = [ + "-O3", + "-I" + os.path.join(os.path.dirname(os.path.abspath(__file__)), "third_party/glm/"), +] + +if cuda_is_available(): + try: + device_id = cuda_current_device() + compute_capability = cuda_get_device_capability(device_id) + sm_version = "".join(map(str, compute_capability)) + nvcc_args.append(f"-gencode=arch=compute_{sm_version},code=sm_{sm_version}") + except Exception as e: + raise RuntimeError(f"Failed during GPU architecture detection: {e}.") +else: + sm_versions = [ + "75", # Turing (GTX 16-series, RTX 20-series, Tesla T4) + "80", # Ampere (A100) + "86", # Ampere (RTX 30-series) + "89", # Ada Lovelace (RTX 40-series, L4, L40) + "90", # Hopper (H100, H200) + "100", # Blackwell (B100) + "101", # Blackwell (B200) + ] + arch_keys = [f"-gencode=arch=compute_{sm_version},code=sm_{sm_version}" for sm_version in sm_versions] + nvcc_args.extend(arch_keys) setup( - name="diff_gaussian_rasterization", - packages=['diff_gaussian_rasterization'], ext_modules=[ CUDAExtension( name="diff_gaussian_rasterization._C", sources=[ - "cuda_rasterizer/rasterizer_impl.cu", - "cuda_rasterizer/forward.cu", - "cuda_rasterizer/backward.cu", - "rasterize_points.cu", - "ext.cpp"], - extra_compile_args={"nvcc": ["-I" + os.path.join(os.path.dirname(os.path.abspath(__file__)), "third_party/glm/")]}) - ], - cmdclass={ - 'build_ext': BuildExtension - } + "cuda_rasterizer/rasterizer_impl.cu", + "cuda_rasterizer/forward.cu", + "cuda_rasterizer/backward.cu", + "rasterize_points.cu", + "ext.cpp", + ], + extra_compile_args={'cxx': ['-O3'], 'nvcc': nvcc_args}, + ), + ], + cmdclass={'build_ext': BuildExtension}, ) From 86c2e2037c159ff79bce45b974a0549181d7a6c2 Mon Sep 17 00:00:00 2001 From: Oleg Semery Date: Mon, 28 Jul 2025 18:42:10 +0300 Subject: [PATCH 5/5] Refactoring --- cuda_rasterizer/auxiliary.h | 78 ++++---- cuda_rasterizer/backward.cu | 262 +++++++++++++-------------- cuda_rasterizer/backward.h | 113 ++++++------ cuda_rasterizer/config.h | 17 +- cuda_rasterizer/forward.cu | 281 +++++++++++++++-------------- cuda_rasterizer/forward.h | 101 +++++------ cuda_rasterizer/rasterizer.h | 38 ++-- cuda_rasterizer/rasterizer_impl.cu | 237 ++++++++++++------------ cuda_rasterizer/rasterizer_impl.h | 26 +-- ext.cpp | 11 -- rasterize_points.cu | 271 +++++++++++++--------------- rasterize_points.h | 113 ++++++------ 12 files changed, 732 insertions(+), 816 deletions(-) diff --git a/cuda_rasterizer/auxiliary.h b/cuda_rasterizer/auxiliary.h index aaf21915..942a10c2 100644 --- a/cuda_rasterizer/auxiliary.h +++ b/cuda_rasterizer/auxiliary.h @@ -1,16 +1,5 @@ -/* - * Copyright (C) 2023, Inria - * GRAPHDECO research group, https://team.inria.fr/graphdeco - * All rights reserved. - * - * This software is free for non-commercial, research and evaluation use - * under the terms of the LICENSE.md file. - * - * For inquiries contact george.drettakis@inria.fr - */ - -#ifndef CUDA_RASTERIZER_AUXILIARY_H_INCLUDED -#define CUDA_RASTERIZER_AUXILIARY_H_INCLUDED +#ifndef CUDA_RASTERIZER_AUXILIARY_H +#define CUDA_RASTERIZER_AUXILIARY_H #include "config.h" #include "stdio.h" @@ -37,13 +26,16 @@ __device__ const float SH_C3[] = { -0.5900435899266435f }; -__forceinline__ __device__ float ndc2Pix(float v, int S) -{ +__forceinline__ __device__ float ndc2Pix(const float v, + const int S) { return ((v + 1.0) * S - 1.0) * 0.5; } -__forceinline__ __device__ void getRect(const float2 p, int max_radius, uint2& rect_min, uint2& rect_max, dim3 grid) -{ +__forceinline__ __device__ void getRect(const float2 p, + const int max_radius, + uint2& rect_min, + uint2& rect_max, + const dim3 grid) { rect_min = { min(grid.x, max((int)0, (int)((p.x - max_radius) / BLOCK_X))), min(grid.y, max((int)0, (int)((p.y - max_radius) / BLOCK_Y))) @@ -54,8 +46,11 @@ __forceinline__ __device__ void getRect(const float2 p, int max_radius, uint2& r }; } -__forceinline__ __device__ void getRect(const float2 p, int2 ext_rect, uint2& rect_min, uint2& rect_max, dim3 grid) -{ +__forceinline__ __device__ void getRect(const float2 p, + const int2 ext_rect, + uint2& rect_min, + uint2& rect_max, + const dim3 grid) { rect_min = { min(grid.x, max((int)0, (int)((p.x - ext_rect.x) / BLOCK_X))), min(grid.y, max((int)0, (int)((p.y - ext_rect.y) / BLOCK_Y))) @@ -67,8 +62,8 @@ __forceinline__ __device__ void getRect(const float2 p, int2 ext_rect, uint2& re } -__forceinline__ __device__ float3 transformPoint4x3(const float3& p, const float* matrix) -{ +__forceinline__ __device__ float3 transformPoint4x3(const float3& p, + const float* matrix) { float3 transformed = { matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12], matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13], @@ -77,8 +72,8 @@ __forceinline__ __device__ float3 transformPoint4x3(const float3& p, const float return transformed; } -__forceinline__ __device__ float4 transformPoint4x4(const float3& p, const float* matrix) -{ +__forceinline__ __device__ float4 transformPoint4x4(const float3& p, + const float* matrix) { float4 transformed = { matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12], matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13], @@ -88,8 +83,8 @@ __forceinline__ __device__ float4 transformPoint4x4(const float3& p, const float return transformed; } -__forceinline__ __device__ float3 transformVec4x3(const float3& p, const float* matrix) -{ +__forceinline__ __device__ float3 transformVec4x3(const float3& p, + const float* matrix) { float3 transformed = { matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z, matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z, @@ -98,8 +93,8 @@ __forceinline__ __device__ float3 transformVec4x3(const float3& p, const float* return transformed; } -__forceinline__ __device__ float3 transformVec4x3Transpose(const float3& p, const float* matrix) -{ +__forceinline__ __device__ float3 transformVec4x3Transpose(const float3& p, + const float* matrix) { float3 transformed = { matrix[0] * p.x + matrix[1] * p.y + matrix[2] * p.z, matrix[4] * p.x + matrix[5] * p.y + matrix[6] * p.z, @@ -108,16 +103,16 @@ __forceinline__ __device__ float3 transformVec4x3Transpose(const float3& p, cons return transformed; } -__forceinline__ __device__ float dnormvdz(float3 v, float3 dv) -{ +__forceinline__ __device__ float dnormvdz(float3 v, + float3 dv) { float sum2 = v.x * v.x + v.y * v.y + v.z * v.z; float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2); float dnormvdz = (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) * invsum32; return dnormvdz; } -__forceinline__ __device__ float3 dnormvdv(float3 v, float3 dv) -{ +__forceinline__ __device__ float3 dnormvdv(float3 v, + float3 dv) { float sum2 = v.x * v.x + v.y * v.y + v.z * v.z; float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2); @@ -128,8 +123,8 @@ __forceinline__ __device__ float3 dnormvdv(float3 v, float3 dv) return dnormvdv; } -__forceinline__ __device__ float4 dnormvdv(float4 v, float4 dv) -{ +__forceinline__ __device__ float4 dnormvdv(float4 v, + float4 dv) { float sum2 = v.x * v.x + v.y * v.y + v.z * v.z + v.w * v.w; float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2); @@ -143,18 +138,17 @@ __forceinline__ __device__ float4 dnormvdv(float4 v, float4 dv) return dnormvdv; } -__forceinline__ __device__ float sigmoid(float x) +__forceinline__ __device__ float sigmoid(const float x) { return 1.0f / (1.0f + expf(-x)); } -__forceinline__ __device__ bool in_frustum(int idx, - const float* orig_points, - const float* viewmatrix, - const float* projmatrix, - bool prefiltered, - float3& p_view) -{ +__forceinline__ __device__ bool in_frustum(const int idx, + const float* orig_points, + const float* viewmatrix, + const float* projmatrix, + const bool prefiltered, + float3& p_view) { float3 p_orig = { orig_points[3 * idx], orig_points[3 * idx + 1], orig_points[3 * idx + 2] }; // Bring points to screen space @@ -184,4 +178,4 @@ throw std::runtime_error(cudaGetErrorString(ret)); \ } \ } -#endif +#endif // CUDA_RASTERIZER_AUXILIARY_H diff --git a/cuda_rasterizer/backward.cu b/cuda_rasterizer/backward.cu index 006b1ef7..8088fa42 100644 --- a/cuda_rasterizer/backward.cu +++ b/cuda_rasterizer/backward.cu @@ -1,14 +1,3 @@ -/* - * Copyright (C) 2023, Inria - * GRAPHDECO research group, https://team.inria.fr/graphdeco - * All rights reserved. - * - * This software is free for non-commercial, research and evaluation use - * under the terms of the LICENSE.md file. - * - * For inquiries contact george.drettakis@inria.fr - */ - #include "backward.h" #include "auxiliary.h" #include @@ -20,8 +9,16 @@ __device__ __forceinline__ float sq(float x) { return x * x; } // Backward pass for conversion of spherical harmonics to RGB for // each Gaussian. -__device__ void computeColorFromSH(int idx, int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float* shs, const bool* clamped, const glm::vec3* dL_dcolor, glm::vec3* dL_dmeans, glm::vec3* dL_dshs) -{ +__device__ void computeColorFromSH(int idx, + int deg, + int max_coeffs, + const glm::vec3* means, + glm::vec3 campos, + const float* shs, + const bool* clamped, + const glm::vec3* dL_dcolor, + glm::vec3* dL_dmeans, + glm::vec3* dL_dshs) { // Compute intermediate values, as it is done during forward glm::vec3 pos = means[idx]; glm::vec3 dir_orig = pos - campos; @@ -144,21 +141,22 @@ __device__ void computeColorFromSH(int idx, int deg, int max_coeffs, const glm:: // Backward version of INVERSE 2D covariance matrix computation // (due to length launched as separate kernel before other // backward steps contained in preprocess) -__global__ void computeCov2DCUDA(int P, - const float3* means, - const int* radii, - const float* cov3Ds, - const float h_x, float h_y, - const float tan_fovx, float tan_fovy, - const float* view_matrix, - const float* opacities, - const float* dL_dconics, - float* dL_dopacity, - const float* dL_dinvdepth, - float3* dL_dmeans, - float* dL_dcov, - bool antialiasing) -{ +__global__ void computeCov2DCUDA(const int P, + const float3* means, + const int* radii, + const float* cov3Ds, + const float h_x, + const float h_y, + const float tan_fovx, + const float tan_fovy, + const float* view_matrix, + const float* opacities, + const float* dL_dconics, + float* dL_dopacity, + const float* dL_dinvdepth, + float3* dL_dmeans, + float* dL_dcov, + bool antialiasing) { auto idx = cg::this_grid().thread_rank(); if (idx >= P || !(radii[idx] > 0)) return; @@ -327,8 +325,13 @@ __global__ void computeCov2DCUDA(int P, // Backward pass for the conversion of scale and rotation to a // 3D covariance matrix for each Gaussian. -__device__ void computeCov3D(int idx, const glm::vec3 scale, float mod, const glm::vec4 rot, const float* dL_dcov3Ds, glm::vec3* dL_dscales, glm::vec4* dL_drots) -{ +__device__ void computeCov3D(int idx, + const glm::vec3 scale, + float mod, + const glm::vec4 rot, + const float* dL_dcov3Ds, + glm::vec3* dL_dscales, + glm::vec4* dL_drots) { // Recompute (intermediate) results for the 3D covariance computation. glm::vec4 q = rot;// / glm::length(rot); float r = q.x; @@ -396,26 +399,26 @@ __device__ void computeCov3D(int idx, const glm::vec3 scale, float mod, const gl // for the covariance computation and inversion // (those are handled by a previous kernel call) template -__global__ void preprocessCUDA( - int P, int D, int M, - const float3* means, - const int* radii, - const float* shs, - const bool* clamped, - const glm::vec3* scales, - const glm::vec4* rotations, - const float scale_modifier, - const float* proj, - const glm::vec3* campos, - const float3* dL_dmean2D, - glm::vec3* dL_dmeans, - float* dL_dcolor, - float* dL_dcov3D, - float* dL_dsh, - glm::vec3* dL_dscale, - glm::vec4* dL_drot, - float* dL_dopacity) -{ +__global__ void preprocessCUDA(int P, + int D, + int M, + const float3* means, + const int* radii, + const float* shs, + const bool* clamped, + const glm::vec3* scales, + const glm::vec4* rotations, + const float scale_modifier, + const float* proj, + const glm::vec3* campos, + const float3* dL_dmean2D, + glm::vec3* dL_dmeans, + float* dL_dcolor, + float* dL_dcov3D, + float* dL_dsh, + glm::vec3* dL_dscale, + glm::vec4* dL_drot, + float* dL_dopacity) { auto idx = cg::this_grid().thread_rank(); if (idx >= P || !(radii[idx] > 0)) return; @@ -450,27 +453,24 @@ __global__ void preprocessCUDA( // Backward version of the rendering procedure. template -__global__ void __launch_bounds__(BLOCK_X * BLOCK_Y) -renderCUDA( - const uint2* __restrict__ ranges, - const uint32_t* __restrict__ point_list, - int W, int H, - const float* __restrict__ bg_color, - const float2* __restrict__ points_xy_image, - const float4* __restrict__ conic_opacity, - const float* __restrict__ colors, - const float* __restrict__ depths, - const float* __restrict__ final_Ts, - const uint32_t* __restrict__ n_contrib, - const float* __restrict__ dL_dpixels, - const float* __restrict__ dL_invdepths, - float3* __restrict__ dL_dmean2D, - float4* __restrict__ dL_dconic2D, - float* __restrict__ dL_dopacity, - float* __restrict__ dL_dcolors, - float* __restrict__ dL_dinvdepths -) -{ +__global__ void __launch_bounds__(BLOCK_X * BLOCK_Y) renderCUDA(const uint2* __restrict__ ranges, + const uint32_t* __restrict__ point_list, + int W, + int H, + const float* __restrict__ bg_color, + const float2* __restrict__ points_xy_image, + const float4* __restrict__ conic_opacity, + const float* __restrict__ colors, + const float* __restrict__ depths, + const float* __restrict__ final_Ts, + const uint32_t* __restrict__ n_contrib, + const float* __restrict__ dL_dpixels, + const float* __restrict__ dL_invdepths, + float3* __restrict__ dL_dmean2D, + float4* __restrict__ dL_dconic2D, + float* __restrict__ dL_dopacity, + float* __restrict__ dL_dcolors, + float* __restrict__ dL_dinvdepths) { // We rasterize again. Compute necessary block info. auto block = cg::this_thread_block(); const uint32_t horizontal_blocks = (W + BLOCK_X - 1) / BLOCK_X; @@ -528,8 +528,7 @@ renderCUDA( const float ddely_dy = 0.5 * H; // Traverse all Gaussians - for (int i = 0; i < rounds; i++, toDo -= BLOCK_SIZE) - { + for (int i = 0; i < rounds; i++, toDo -= BLOCK_SIZE) { // Load auxiliary data into shared memory, start in the BACK // and load them in revers order. block.sync(); @@ -549,8 +548,7 @@ renderCUDA( block.sync(); // Iterate over Gaussians - for (int j = 0; !done && j < min(BLOCK_SIZE, toDo); j++) - { + for (int j = 0; !done && j < min(BLOCK_SIZE, toDo); j++) { // Keep track of current Gaussian ID. Skip, if this one // is behind the last contributor for this pixel. contributor--; @@ -594,13 +592,12 @@ renderCUDA( } // Propagate gradients from inverse depth to alphaas and // per Gaussian inverse depths - if (dL_dinvdepths) - { - const float invd = 1.f / collected_depths[j]; - accum_invdepth_rec = last_alpha * last_invdepth + (1.f - last_alpha) * accum_invdepth_rec; - last_invdepth = invd; - dL_dalpha += (invd - accum_invdepth_rec) * dL_invdepth; - atomicAdd(&(dL_dinvdepths[global_id]), dchannel_dcolor * dL_invdepth); + if (dL_dinvdepths) { + const float invd = 1.f / collected_depths[j]; + accum_invdepth_rec = last_alpha * last_invdepth + (1.f - last_alpha) * accum_invdepth_rec; + last_invdepth = invd; + dL_dalpha += (invd - accum_invdepth_rec) * dL_invdepth; + atomicAdd(&(dL_dinvdepths[global_id]), dchannel_dcolor * dL_invdepth); } dL_dalpha *= T; @@ -637,34 +634,36 @@ renderCUDA( } } -void BACKWARD::preprocess( - int P, int D, int M, - const float3* means3D, - const int* radii, - const float* shs, - const bool* clamped, - const float* opacities, - const glm::vec3* scales, - const glm::vec4* rotations, - const float scale_modifier, - const float* cov3Ds, - const float* viewmatrix, - const float* projmatrix, - const float focal_x, float focal_y, - const float tan_fovx, float tan_fovy, - const glm::vec3* campos, - const float3* dL_dmean2D, - const float* dL_dconic, - const float* dL_dinvdepth, - float* dL_dopacity, - glm::vec3* dL_dmean3D, - float* dL_dcolor, - float* dL_dcov3D, - float* dL_dsh, - glm::vec3* dL_dscale, - glm::vec4* dL_drot, - bool antialiasing) -{ +void BACKWARD::preprocess(const int P, + const int D, + const int M, + const float3* means3D, + const int* radii, + const float* shs, + const bool* clamped, + const float* opacities, + const glm::vec3* scales, + const glm::vec4* rotations, + const float scale_modifier, + const float* cov3Ds, + const float* viewmatrix, + const float* projmatrix, + const float focal_x, + const float focal_y, + const float tan_fovx, + const float tan_fovy, + const glm::vec3* campos, + const float3* dL_dmean2D, + const float* dL_dconic, + const float* dL_dinvdepth, + float* dL_dopacity, + glm::vec3* dL_dmean3D, + float* dL_dcolor, + float* dL_dcov3D, + float* dL_dsh, + glm::vec3* dL_dscale, + glm::vec4* dL_drot, + bool antialiasing) { // Propagate gradients for the path of 2D conic matrix computation. // Somewhat long, thus it is its own kernel rather than being part of // "preprocess". When done, loss gradient w.r.t. 3D means has been @@ -711,26 +710,26 @@ void BACKWARD::preprocess( dL_dopacity); } -void BACKWARD::render( - const dim3 grid, const dim3 block, - const uint2* ranges, - const uint32_t* point_list, - int W, int H, - const float* bg_color, - const float2* means2D, - const float4* conic_opacity, - const float* colors, - const float* depths, - const float* final_Ts, - const uint32_t* n_contrib, - const float* dL_dpixels, - const float* dL_invdepths, - float3* dL_dmean2D, - float4* dL_dconic2D, - float* dL_dopacity, - float* dL_dcolors, - float* dL_dinvdepths) -{ +void BACKWARD::render(const dim3 grid, + const dim3 block, + const uint2* ranges, + const uint32_t* point_list, + int W, + int H, + const float* bg_color, + const float2* means2D, + const float4* conic_opacity, + const float* colors, + const float* depths, + const float* final_Ts, + const uint32_t* n_contrib, + const float* dL_dpixels, + const float* dL_invdepths, + float3* dL_dmean2D, + float4* dL_dconic2D, + float* dL_dopacity, + float* dL_dcolors, + float* dL_dinvdepths) { renderCUDA << > >( ranges, point_list, @@ -748,6 +747,5 @@ void BACKWARD::render( dL_dconic2D, dL_dopacity, dL_dcolors, - dL_dinvdepths - ); + dL_dinvdepths); } diff --git a/cuda_rasterizer/backward.h b/cuda_rasterizer/backward.h index 4d02560f..464f2070 100644 --- a/cuda_rasterizer/backward.h +++ b/cuda_rasterizer/backward.h @@ -1,16 +1,5 @@ -/* - * Copyright (C) 2023, Inria - * GRAPHDECO research group, https://team.inria.fr/graphdeco - * All rights reserved. - * - * This software is free for non-commercial, research and evaluation use - * under the terms of the LICENSE.md file. - * - * For inquiries contact george.drettakis@inria.fr - */ - -#ifndef CUDA_RASTERIZER_BACKWARD_H_INCLUDED -#define CUDA_RASTERIZER_BACKWARD_H_INCLUDED +#ifndef CUDA_RASTERIZER_BACKWARD_H +#define CUDA_RASTERIZER_BACKWARD_H #include #include "cuda_runtime.h" @@ -20,53 +9,57 @@ namespace BACKWARD { - void render( - const dim3 grid, dim3 block, - const uint2* ranges, - const uint32_t* point_list, - int W, int H, - const float* bg_color, - const float2* means2D, - const float4* conic_opacity, - const float* colors, - const float* depths, - const float* final_Ts, - const uint32_t* n_contrib, - const float* dL_dpixels, - const float* dL_invdepths, - float3* dL_dmean2D, - float4* dL_dconic2D, - float* dL_dopacity, - float* dL_dcolors, - float* dL_dinvdepths); + void render(const dim3 grid, + dim3 block, + const uint2* ranges, + const uint32_t* point_list, + const int W, + const int H, + const float* bg_color, + const float2* means2D, + const float4* conic_opacity, + const float* colors, + const float* depths, + const float* final_Ts, + const uint32_t* n_contrib, + const float* dL_dpixels, + const float* dL_invdepths, + float3* dL_dmean2D, + float4* dL_dconic2D, + float* dL_dopacity, + float* dL_dcolors, + float* dL_dinvdepths); - void preprocess( - int P, int D, int M, - const float3* means, - const int* radii, - const float* shs, - const bool* clamped, - const float* opacities, - const glm::vec3* scales, - const glm::vec4* rotations, - const float scale_modifier, - const float* cov3Ds, - const float* view, - const float* proj, - const float focal_x, float focal_y, - const float tan_fovx, float tan_fovy, - const glm::vec3* campos, - const float3* dL_dmean2D, - const float* dL_dconics, - const float* dL_dinvdepth, - float* dL_dopacity, - glm::vec3* dL_dmeans, - float* dL_dcolor, - float* dL_dcov3D, - float* dL_dsh, - glm::vec3* dL_dscale, - glm::vec4* dL_drot, - bool antialiasing); + void preprocess(const int P, + const int D, + const int M, + const float3* means, + const int* radii, + const float* shs, + const bool* clamped, + const float* opacities, + const glm::vec3* scales, + const glm::vec4* rotations, + const float scale_modifier, + const float* cov3Ds, + const float* view, + const float* proj, + const float focal_x, + const float focal_y, + const float tan_fovx, + const float tan_fovy, + const glm::vec3* campos, + const float3* dL_dmean2D, + const float* dL_dconics, + const float* dL_dinvdepth, + float* dL_dopacity, + glm::vec3* dL_dmeans, + float* dL_dcolor, + float* dL_dcov3D, + float* dL_dsh, + glm::vec3* dL_dscale, + glm::vec4* dL_drot, + bool antialiasing); } -#endif +#endif // CUDA_RASTERIZER_BACKWARD_H diff --git a/cuda_rasterizer/config.h b/cuda_rasterizer/config.h index 2a912fb3..6db8bd1c 100644 --- a/cuda_rasterizer/config.h +++ b/cuda_rasterizer/config.h @@ -1,18 +1,7 @@ -/* - * Copyright (C) 2023, Inria - * GRAPHDECO research group, https://team.inria.fr/graphdeco - * All rights reserved. - * - * This software is free for non-commercial, research and evaluation use - * under the terms of the LICENSE.md file. - * - * For inquiries contact george.drettakis@inria.fr - */ +#ifndef CUDA_RASTERIZER_CONFIG_H +#define CUDA_RASTERIZER_CONFIG_H -#ifndef CUDA_RASTERIZER_CONFIG_H_INCLUDED -#define CUDA_RASTERIZER_CONFIG_H_INCLUDED - -#define NUM_CHANNELS 3 // Default 3, RGB +#define NUM_CHANNELS 3 // Default 3, RGB #define BLOCK_X 16 #define BLOCK_Y 16 diff --git a/cuda_rasterizer/forward.cu b/cuda_rasterizer/forward.cu index c5e01ddc..d4be47a1 100644 --- a/cuda_rasterizer/forward.cu +++ b/cuda_rasterizer/forward.cu @@ -1,14 +1,3 @@ -/* - * Copyright (C) 2023, Inria - * GRAPHDECO research group, https://team.inria.fr/graphdeco - * All rights reserved. - * - * This software is free for non-commercial, research and evaluation use - * under the terms of the LICENSE.md file. - * - * For inquiries contact george.drettakis@inria.fr - */ - #include "forward.h" #include "auxiliary.h" #include @@ -17,12 +6,17 @@ namespace cg = cooperative_groups; // Forward method for converting the input spherical harmonics // coefficients of each Gaussian to a simple RGB color. -__device__ glm::vec3 computeColorFromSH(int idx, int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float* shs, bool* clamped) -{ +__device__ glm::vec3 computeColorFromSH(int idx, + int deg, + int max_coeffs, + const glm::vec3* means, + glm::vec3 campos, + const float* shs, + bool* clamped) { // The implementation is loosely based on code for // "Differentiable Point-Based Radiance Fields for // Efficient View Synthesis" by Zhang et al. (2022) - glm::vec3 pos = means[idx]; + const glm::vec3 pos = means[idx]; glm::vec3 dir = pos - campos; dir = dir / glm::length(dir); @@ -31,15 +25,15 @@ __device__ glm::vec3 computeColorFromSH(int idx, int deg, int max_coeffs, const if (deg > 0) { - float x = dir.x; - float y = dir.y; - float z = dir.z; + const float x = dir.x; + const float y = dir.y; + const float z = dir.z; result = result - SH_C1 * y * sh[1] + SH_C1 * z * sh[2] - SH_C1 * x * sh[3]; if (deg > 1) { - float xx = x * x, yy = y * y, zz = z * z; - float xy = x * y, yz = y * z, xz = x * z; + const float xx = x * x, yy = y * y, zz = z * z; + const float xy = x * y, yz = y * z, xz = x * z; result = result + SH_C2[0] * xy * sh[4] + SH_C2[1] * yz * sh[5] + @@ -71,8 +65,13 @@ __device__ glm::vec3 computeColorFromSH(int idx, int deg, int max_coeffs, const } // Forward version of 2D covariance matrix computation -__device__ float3 computeCov2D(const float3& mean, float focal_x, float focal_y, float tan_fovx, float tan_fovy, const float* cov3D, const float* viewmatrix) -{ +__device__ float3 computeCov2D(const float3& mean, + float focal_x, + float focal_y, + float tan_fovx, + float tan_fovy, + const float* cov3D, + const float* viewmatrix) { // The following models the steps outlined by equations 29 // and 31 in "EWA Splatting" (Zwicker et al., 2002). // Additionally considers aspect / scaling of viewport. @@ -111,8 +110,10 @@ __device__ float3 computeCov2D(const float3& mean, float focal_x, float focal_y, // Forward method for converting scale and rotation properties of each // Gaussian to a 3D covariance matrix in world space. Also takes care // of quaternion normalization. -__device__ void computeCov3D(const glm::vec3 scale, float mod, const glm::vec4 rot, float* cov3D) -{ +__device__ void computeCov3D(const glm::vec3 scale, + float mod, + const glm::vec4 rot, + float* cov3D) { // Create scaling matrix glm::mat3 S = glm::mat3(1.0f); S[0][0] = mod * scale.x; @@ -149,33 +150,37 @@ __device__ void computeCov3D(const glm::vec3 scale, float mod, const glm::vec4 r // Perform initial steps for each Gaussian prior to rasterization. template -__global__ void preprocessCUDA(int P, int D, int M, - const float* orig_points, - const glm::vec3* scales, - const float scale_modifier, - const glm::vec4* rotations, - const float* opacities, - const float* shs, - bool* clamped, - const float* cov3D_precomp, - const float* colors_precomp, - const float* viewmatrix, - const float* projmatrix, - const glm::vec3* cam_pos, - const int W, int H, - const float tan_fovx, float tan_fovy, - const float focal_x, float focal_y, - int* radii, - float2* points_xy_image, - float* depths, - float* cov3Ds, - float* rgb, - float4* conic_opacity, - const dim3 grid, - uint32_t* tiles_touched, - bool prefiltered, - bool antialiasing) -{ +__global__ void preprocessCUDA(const int P, + const int D, + const int M, + const float* orig_points, + const glm::vec3* scales, + const float scale_modifier, + const glm::vec4* rotations, + const float* opacities, + const float* shs, + bool* clamped, + const float* cov3D_precomp, + const float* colors_precomp, + const float* viewmatrix, + const float* projmatrix, + const glm::vec3* cam_pos, + const int W, + const int H, + const float tan_fovx, + const float tan_fovy, + const float focal_x, + const float focal_y, + int* radii, + float2* points_xy_image, + float* depths, + float* cov3Ds, + float* rgb, + float4* conic_opacity, + const dim3 grid, + uint32_t* tiles_touched, + bool prefiltered, + bool antialiasing) { auto idx = cg::this_grid().thread_rank(); if (idx >= P) return; @@ -268,25 +273,84 @@ __global__ void preprocessCUDA(int P, int D, int M, tiles_touched[idx] = (rect_max.y - rect_min.y) * (rect_max.x - rect_min.x); } +void FORWARD::preprocess(const int P, + const int D, + const int M, + const float* means3D, + const glm::vec3* scales, + const float scale_modifier, + const glm::vec4* rotations, + const float* opacities, + const float* shs, + bool* clamped, + const float* cov3D_precomp, + const float* colors_precomp, + const float* viewmatrix, + const float* projmatrix, + const glm::vec3* cam_pos, + const int W, + const int H, + const float focal_x, + const float focal_y, + const float tan_fovx, + const float tan_fovy, + int* radii, + float2* means2D, + float* depths, + float* cov3Ds, + float* rgb, + float4* conic_opacity, + const dim3 grid, + uint32_t* tiles_touched, + bool prefiltered, + bool antialiasing) { + preprocessCUDA << <(P + 255) / 256, 256 >> > ( + P, D, M, + means3D, + scales, + scale_modifier, + rotations, + opacities, + shs, + clamped, + cov3D_precomp, + colors_precomp, + viewmatrix, + projmatrix, + cam_pos, + W, H, + tan_fovx, tan_fovy, + focal_x, focal_y, + radii, + means2D, + depths, + cov3Ds, + rgb, + conic_opacity, + grid, + tiles_touched, + prefiltered, + antialiasing + ); +} + // Main rasterization method. Collaboratively works on one tile per // block, each thread treats one pixel. Alternates between fetching // and rasterizing data. template -__global__ void __launch_bounds__(BLOCK_X * BLOCK_Y) -renderCUDA( - const uint2* __restrict__ ranges, - const uint32_t* __restrict__ point_list, - int W, int H, - const float2* __restrict__ points_xy_image, - const float* __restrict__ features, - const float4* __restrict__ conic_opacity, - float* __restrict__ final_T, - uint32_t* __restrict__ n_contrib, - const float* __restrict__ bg_color, - float* __restrict__ out_color, - const float* __restrict__ depths, - float* __restrict__ invdepth) -{ +__global__ void __launch_bounds__(BLOCK_X * BLOCK_Y) renderCUDA(const uint2* __restrict__ ranges, + const uint32_t* __restrict__ point_list, + int W, + int H, + const float2* __restrict__ points_xy_image, + const float* __restrict__ features, + const float4* __restrict__ conic_opacity, + float* __restrict__ final_T, + uint32_t* __restrict__ n_contrib, + const float* __restrict__ bg_color, + float* __restrict__ out_color, + const float* __restrict__ depths, + float* __restrict__ invdepth) { // Identify current tile and associated min/max pixel range. auto block = cg::this_thread_block(); uint32_t horizontal_blocks = (W + BLOCK_X - 1) / BLOCK_X; @@ -396,21 +460,21 @@ renderCUDA( } } -void FORWARD::render( - const dim3 grid, dim3 block, - const uint2* ranges, - const uint32_t* point_list, - int W, int H, - const float2* means2D, - const float* colors, - const float4* conic_opacity, - float* final_T, - uint32_t* n_contrib, - const float* bg_color, - float* out_color, - float* depths, - float* depth) -{ +void FORWARD::render(const dim3 grid, + const dim3 block, + const uint2* ranges, + const uint32_t* point_list, + const int W, + const int H, + const float2* means2D, + const float* colors, + const float4* conic_opacity, + float* final_T, + uint32_t* n_contrib, + const float* bg_color, + float* out_color, + float* depths, + float* depth) { renderCUDA << > > ( ranges, point_list, @@ -425,60 +489,3 @@ void FORWARD::render( depths, depth); } - -void FORWARD::preprocess(int P, int D, int M, - const float* means3D, - const glm::vec3* scales, - const float scale_modifier, - const glm::vec4* rotations, - const float* opacities, - const float* shs, - bool* clamped, - const float* cov3D_precomp, - const float* colors_precomp, - const float* viewmatrix, - const float* projmatrix, - const glm::vec3* cam_pos, - const int W, int H, - const float focal_x, float focal_y, - const float tan_fovx, float tan_fovy, - int* radii, - float2* means2D, - float* depths, - float* cov3Ds, - float* rgb, - float4* conic_opacity, - const dim3 grid, - uint32_t* tiles_touched, - bool prefiltered, - bool antialiasing) -{ - preprocessCUDA << <(P + 255) / 256, 256 >> > ( - P, D, M, - means3D, - scales, - scale_modifier, - rotations, - opacities, - shs, - clamped, - cov3D_precomp, - colors_precomp, - viewmatrix, - projmatrix, - cam_pos, - W, H, - tan_fovx, tan_fovy, - focal_x, focal_y, - radii, - means2D, - depths, - cov3Ds, - rgb, - conic_opacity, - grid, - tiles_touched, - prefiltered, - antialiasing - ); -} diff --git a/cuda_rasterizer/forward.h b/cuda_rasterizer/forward.h index 5b18005f..d8de8b8d 100644 --- a/cuda_rasterizer/forward.h +++ b/cuda_rasterizer/forward.h @@ -1,16 +1,5 @@ -/* - * Copyright (C) 2023, Inria - * GRAPHDECO research group, https://team.inria.fr/graphdeco - * All rights reserved. - * - * This software is free for non-commercial, research and evaluation use - * under the terms of the LICENSE.md file. - * - * For inquiries contact george.drettakis@inria.fr - */ - -#ifndef CUDA_RASTERIZER_FORWARD_H_INCLUDED -#define CUDA_RASTERIZER_FORWARD_H_INCLUDED +#ifndef CUDA_RASTERIZER_FORWARD_H +#define CUDA_RASTERIZER_FORWARD_H #include #include "cuda_runtime.h" @@ -21,48 +10,54 @@ namespace FORWARD { // Perform initial steps for each Gaussian prior to rasterization. - void preprocess(int P, int D, int M, - const float* orig_points, - const glm::vec3* scales, - const float scale_modifier, - const glm::vec4* rotations, - const float* opacities, - const float* shs, - bool* clamped, - const float* cov3D_precomp, - const float* colors_precomp, - const float* viewmatrix, - const float* projmatrix, - const glm::vec3* cam_pos, - const int W, int H, - const float focal_x, float focal_y, - const float tan_fovx, float tan_fovy, - int* radii, - float2* points_xy_image, - float* depths, - float* cov3Ds, - float* colors, - float4* conic_opacity, - const dim3 grid, - uint32_t* tiles_touched, - bool prefiltered, - bool antialiasing); + void preprocess(const int P, + const int D, + const int M, + const float* orig_points, + const glm::vec3* scales, + const float scale_modifier, + const glm::vec4* rotations, + const float* opacities, + const float* shs, + bool* clamped, + const float* cov3D_precomp, + const float* colors_precomp, + const float* viewmatrix, + const float* projmatrix, + const glm::vec3* cam_pos, + const int W, + const int H, + const float focal_x, + const float focal_y, + const float tan_fovx, + const float tan_fovy, + int* radii, + float2* points_xy_image, + float* depths, + float* cov3Ds, + float* colors, + float4* conic_opacity, + const dim3 grid, + uint32_t* tiles_touched, + bool prefiltered, + bool antialiasing); // Main rasterization method. - void render( - const dim3 grid, dim3 block, - const uint2* ranges, - const uint32_t* point_list, - int W, int H, - const float2* points_xy_image, - const float* features, - const float4* conic_opacity, - float* final_T, - uint32_t* n_contrib, - const float* bg_color, - float* out_color, - float* depths, - float* depth); + void render(const dim3 grid, + const dim3 block, + const uint2* ranges, + const uint32_t* point_list, + const int W, + const int H, + const float2* points_xy_image, + const float* features, + const float4* conic_opacity, + float* final_T, + uint32_t* n_contrib, + const float* bg_color, + float* out_color, + float* depths, + float* depth); } diff --git a/cuda_rasterizer/rasterizer.h b/cuda_rasterizer/rasterizer.h index 64afdee9..7baffa5e 100644 --- a/cuda_rasterizer/rasterizer.h +++ b/cuda_rasterizer/rasterizer.h @@ -1,16 +1,5 @@ -/* - * Copyright (C) 2023, Inria - * GRAPHDECO research group, https://team.inria.fr/graphdeco - * All rights reserved. - * - * This software is free for non-commercial, research and evaluation use - * under the terms of the LICENSE.md file. - * - * For inquiries contact george.drettakis@inria.fr - */ - -#ifndef CUDA_RASTERIZER_H_INCLUDED -#define CUDA_RASTERIZER_H_INCLUDED +#ifndef CUDA_RASTERIZER_H +#define CUDA_RASTERIZER_H #include #include @@ -32,9 +21,12 @@ namespace CudaRasterizer std::function geometryBuffer, std::function binningBuffer, std::function imageBuffer, - const int P, int D, int M, + const int P, + const int D, + const int M, const float* background, - const int width, int height, + const int width, + const int height, const float* means3D, const float* shs, const float* colors_precomp, @@ -46,7 +38,8 @@ namespace CudaRasterizer const float* viewmatrix, const float* projmatrix, const float* cam_pos, - const float tan_fovx, float tan_fovy, + const float tan_fovx, + const float tan_fovy, const bool prefiltered, float* out_color, float* depth, @@ -55,9 +48,13 @@ namespace CudaRasterizer bool debug = false); static void backward( - const int P, int D, int M, int R, + const int P, + const int D, + const int M, + const int R, const float* background, - const int width, int height, + const int width, + const int height, const float* means3D, const float* shs, const float* colors_precomp, @@ -69,7 +66,8 @@ namespace CudaRasterizer const float* viewmatrix, const float* projmatrix, const float* campos, - const float tan_fovx, float tan_fovy, + const float tan_fovx, + const float tan_fovy, const int* radii, char* geom_buffer, char* binning_buffer, @@ -91,4 +89,4 @@ namespace CudaRasterizer }; }; -#endif +#endif // CUDA_RASTERIZER_H \ No newline at end of file diff --git a/cuda_rasterizer/rasterizer_impl.cu b/cuda_rasterizer/rasterizer_impl.cu index 41c4ed3a..b45f6160 100644 --- a/cuda_rasterizer/rasterizer_impl.cu +++ b/cuda_rasterizer/rasterizer_impl.cu @@ -1,14 +1,3 @@ -/* - * Copyright (C) 2023, Inria - * GRAPHDECO research group, https://team.inria.fr/graphdeco - * All rights reserved. - * - * This software is free for non-commercial, research and evaluation use - * under the terms of the LICENSE.md file. - * - * For inquiries contact george.drettakis@inria.fr - */ - #include "rasterizer_impl.h" #include #include @@ -32,12 +21,10 @@ namespace cg = cooperative_groups; // Helper function to find the next-highest bit of the MSB // on the CPU. -uint32_t getHigherMsb(uint32_t n) -{ +uint32_t getHigherMsb(uint32_t n) { uint32_t msb = sizeof(n) * 4; uint32_t step = msb; - while (step > 1) - { + while (step > 1) { step /= 2; if (n >> msb) msb += step; @@ -52,11 +39,10 @@ uint32_t getHigherMsb(uint32_t n) // Wrapper method to call auxiliary coarse frustum containment test. // Mark all Gaussians that pass it. __global__ void checkFrustum(int P, - const float* orig_points, - const float* viewmatrix, - const float* projmatrix, - bool* present) -{ + const float* orig_points, + const float* viewmatrix, + const float* projmatrix, + bool* present) { auto idx = cg::this_grid().thread_rank(); if (idx >= P) return; @@ -67,23 +53,20 @@ __global__ void checkFrustum(int P, // Generates one key/value pair for all Gaussian / tile overlaps. // Run once per Gaussian (1:N mapping). -__global__ void duplicateWithKeys( - int P, - const float2* points_xy, - const float* depths, - const uint32_t* offsets, - uint64_t* gaussian_keys_unsorted, - uint32_t* gaussian_values_unsorted, - int* radii, - dim3 grid) -{ +__global__ void duplicateWithKeys(int P, + const float2* points_xy, + const float* depths, + const uint32_t* offsets, + uint64_t* gaussian_keys_unsorted, + uint32_t* gaussian_values_unsorted, + int* radii, + dim3 grid) { auto idx = cg::this_grid().thread_rank(); if (idx >= P) return; // Generate no key/value pair for invisible Gaussians - if (radii[idx] > 0) - { + if (radii[idx] > 0) { // Find this Gaussian's offset in buffer for writing keys/values. uint32_t off = (idx == 0) ? 0 : offsets[idx - 1]; uint2 rect_min, rect_max; @@ -95,10 +78,8 @@ __global__ void duplicateWithKeys( // and the value is the ID of the Gaussian. Sorting the values // with this key yields Gaussian IDs in a list, such that they // are first sorted by tile and then by depth. - for (int y = rect_min.y; y < rect_max.y; y++) - { - for (int x = rect_min.x; x < rect_max.x; x++) - { + for (int y = rect_min.y; y < rect_max.y; y++) { + for (int x = rect_min.x; x < rect_max.x; x++) { uint64_t key = y * grid.x + x; key <<= 32; key |= *((uint32_t*)&depths[idx]); @@ -113,8 +94,9 @@ __global__ void duplicateWithKeys( // Check keys to see if it is at the start/end of one tile's range in // the full sorted list. If yes, write start/end of this tile. // Run once per instanced (duplicated) Gaussian ID. -__global__ void identifyTileRanges(int L, uint64_t* point_list_keys, uint2* ranges) -{ +__global__ void identifyTileRanges(int L, + uint64_t* point_list_keys, + uint2* ranges) { auto idx = cg::this_grid().thread_rank(); if (idx >= L) return; @@ -124,11 +106,9 @@ __global__ void identifyTileRanges(int L, uint64_t* point_list_keys, uint2* rang uint32_t currtile = key >> 32; if (idx == 0) ranges[currtile].x = 0; - else - { + else { uint32_t prevtile = point_list_keys[idx - 1] >> 32; - if (currtile != prevtile) - { + if (currtile != prevtile) { ranges[prevtile].y = idx; ranges[currtile].x = idx; } @@ -138,13 +118,11 @@ __global__ void identifyTileRanges(int L, uint64_t* point_list_keys, uint2* rang } // Mark Gaussians as visible/invisible, based on view frustum testing -void CudaRasterizer::Rasterizer::markVisible( - int P, - float* means3D, - float* viewmatrix, - float* projmatrix, - bool* present) -{ +void CudaRasterizer::Rasterizer::markVisible(int P, + float* means3D, + float* viewmatrix, + float* projmatrix, + bool* present) { checkFrustum << <(P + 255) / 256, 256 >> > ( P, means3D, @@ -152,8 +130,8 @@ void CudaRasterizer::Rasterizer::markVisible( present); } -CudaRasterizer::GeometryState CudaRasterizer::GeometryState::fromChunk(char*& chunk, size_t P) -{ +CudaRasterizer::GeometryState CudaRasterizer::GeometryState::fromChunk(char*& chunk, + size_t P) { GeometryState geom; obtain(chunk, geom.depths, P, 128); obtain(chunk, geom.clamped, P * 3, 128); @@ -169,8 +147,8 @@ CudaRasterizer::GeometryState CudaRasterizer::GeometryState::fromChunk(char*& ch return geom; } -CudaRasterizer::ImageState CudaRasterizer::ImageState::fromChunk(char*& chunk, size_t N) -{ +CudaRasterizer::ImageState CudaRasterizer::ImageState::fromChunk(char*& chunk, + size_t N) { ImageState img; obtain(chunk, img.accum_alpha, N, 128); obtain(chunk, img.n_contrib, N, 128); @@ -178,8 +156,8 @@ CudaRasterizer::ImageState CudaRasterizer::ImageState::fromChunk(char*& chunk, s return img; } -CudaRasterizer::BinningState CudaRasterizer::BinningState::fromChunk(char*& chunk, size_t P) -{ +CudaRasterizer::BinningState CudaRasterizer::BinningState::fromChunk(char*& chunk, + size_t P) { BinningState binning; obtain(chunk, binning.point_list, P, 128); obtain(chunk, binning.point_list_unsorted, P, 128); @@ -195,32 +173,34 @@ CudaRasterizer::BinningState CudaRasterizer::BinningState::fromChunk(char*& chun // Forward rendering procedure for differentiable rasterization // of Gaussians. -int CudaRasterizer::Rasterizer::forward( - std::function geometryBuffer, - std::function binningBuffer, - std::function imageBuffer, - const int P, int D, int M, - const float* background, - const int width, int height, - const float* means3D, - const float* shs, - const float* colors_precomp, - const float* opacities, - const float* scales, - const float scale_modifier, - const float* rotations, - const float* cov3D_precomp, - const float* viewmatrix, - const float* projmatrix, - const float* cam_pos, - const float tan_fovx, float tan_fovy, - const bool prefiltered, - float* out_color, - float* depth, - bool antialiasing, - int* radii, - bool debug) -{ +int CudaRasterizer::Rasterizer::forward(std::function geometryBuffer, + std::function binningBuffer, + std::function imageBuffer, + const int P, + const int D, + const int M, + const float* background, + const int width, + const int height, + const float* means3D, + const float* shs, + const float* colors_precomp, + const float* opacities, + const float* scales, + const float scale_modifier, + const float* rotations, + const float* cov3D_precomp, + const float* viewmatrix, + const float* projmatrix, + const float* cam_pos, + const float tan_fovx, + const float tan_fovy, + const bool prefiltered, + float* out_color, + float* depth, + bool antialiasing, + int* radii, + bool debug) { const float focal_y = height / (2.0f * tan_fovy); const float focal_x = width / (2.0f * tan_fovx); @@ -241,14 +221,15 @@ int CudaRasterizer::Rasterizer::forward( char* img_chunkptr = imageBuffer(img_chunk_size); ImageState imgState = ImageState::fromChunk(img_chunkptr, width * height); - if (NUM_CHANNELS != 3 && colors_precomp == nullptr) - { + if (NUM_CHANNELS != 3 && colors_precomp == nullptr) { throw std::runtime_error("For non-RGB, provide precomputed Gaussian colors!"); } // Run preprocessing per-Gaussian (transformation, bounding, conversion of SHs to RGB) CHECK_CUDA(FORWARD::preprocess( - P, D, M, + P, + D, + M, means3D, (glm::vec3*)scales, scale_modifier, @@ -258,11 +239,15 @@ int CudaRasterizer::Rasterizer::forward( geomState.clamped, cov3D_precomp, colors_precomp, - viewmatrix, projmatrix, + viewmatrix, + projmatrix, (glm::vec3*)cam_pos, - width, height, - focal_x, focal_y, - tan_fovx, tan_fovy, + width, + height, + focal_x, + focal_y, + tan_fovx, + tan_fovy, radii, geomState.means2D, geomState.depths, @@ -342,47 +327,49 @@ int CudaRasterizer::Rasterizer::forward( // Produce necessary gradients for optimization, corresponding // to forward render pass -void CudaRasterizer::Rasterizer::backward( - const int P, int D, int M, int R, - const float* background, - const int width, int height, - const float* means3D, - const float* shs, - const float* colors_precomp, - const float* opacities, - const float* scales, - const float scale_modifier, - const float* rotations, - const float* cov3D_precomp, - const float* viewmatrix, - const float* projmatrix, - const float* campos, - const float tan_fovx, float tan_fovy, - const int* radii, - char* geom_buffer, - char* binning_buffer, - char* img_buffer, - const float* dL_dpix, - const float* dL_invdepths, - float* dL_dmean2D, - float* dL_dconic, - float* dL_dopacity, - float* dL_dcolor, - float* dL_dinvdepth, - float* dL_dmean3D, - float* dL_dcov3D, - float* dL_dsh, - float* dL_dscale, - float* dL_drot, - bool antialiasing, - bool debug) -{ +void CudaRasterizer::Rasterizer::backward(const int P, + const int D, + const int M, + const int R, + const float* background, + const int width, + const int height, + const float* means3D, + const float* shs, + const float* colors_precomp, + const float* opacities, + const float* scales, + const float scale_modifier, + const float* rotations, + const float* cov3D_precomp, + const float* viewmatrix, + const float* projmatrix, + const float* campos, + const float tan_fovx, + const float tan_fovy, + const int* radii, + char* geom_buffer, + char* binning_buffer, + char* img_buffer, + const float* dL_dpix, + const float* dL_invdepths, + float* dL_dmean2D, + float* dL_dconic, + float* dL_dopacity, + float* dL_dcolor, + float* dL_dinvdepth, + float* dL_dmean3D, + float* dL_dcov3D, + float* dL_dsh, + float* dL_dscale, + float* dL_drot, + bool antialiasing, + bool debug) { GeometryState geomState = GeometryState::fromChunk(geom_buffer, P); BinningState binningState = BinningState::fromChunk(binning_buffer, R); ImageState imgState = ImageState::fromChunk(img_buffer, width * height); - if (radii == nullptr) - { + if (radii == nullptr) { radii = geomState.internal_radii; } diff --git a/cuda_rasterizer/rasterizer_impl.h b/cuda_rasterizer/rasterizer_impl.h index 9cb61543..60ad1274 100644 --- a/cuda_rasterizer/rasterizer_impl.h +++ b/cuda_rasterizer/rasterizer_impl.h @@ -1,15 +1,6 @@ -/* - * Copyright (C) 2023, Inria - * GRAPHDECO research group, https://team.inria.fr/graphdeco - * All rights reserved. - * - * This software is free for non-commercial, research and evaluation use - * under the terms of the LICENSE.md file. - * - * For inquiries contact george.drettakis@inria.fr - */ - #pragma once +#ifndef CUDA_RASTERIZER_IMPL_H +#define CUDA_RASTERIZER_IMPL_H #include #include @@ -20,8 +11,10 @@ namespace CudaRasterizer { template - static void obtain(char*& chunk, T*& ptr, std::size_t count, std::size_t alignment) - { + static void obtain(char*& chunk, + T*& ptr, + std::size_t count, + std::size_t alignment) { std::size_t offset = (reinterpret_cast(chunk) + alignment - 1) & ~(alignment - 1); ptr = reinterpret_cast(offset); chunk = reinterpret_cast(ptr + count); @@ -66,10 +59,11 @@ namespace CudaRasterizer }; template - size_t required(size_t P) - { + size_t required(size_t P) { char* size = nullptr; T::fromChunk(size, P); return ((size_t)size) + 128; } -}; \ No newline at end of file +}; + +#endif // CUDA_RASTERIZER_IMPL_H \ No newline at end of file diff --git a/ext.cpp b/ext.cpp index d7687795..64369517 100644 --- a/ext.cpp +++ b/ext.cpp @@ -1,14 +1,3 @@ -/* - * Copyright (C) 2023, Inria - * GRAPHDECO research group, https://team.inria.fr/graphdeco - * All rights reserved. - * - * This software is free for non-commercial, research and evaluation use - * under the terms of the LICENSE.md file. - * - * For inquiries contact george.drettakis@inria.fr - */ - #include #include "rasterize_points.h" diff --git a/rasterize_points.cu b/rasterize_points.cu index e625c19e..a7ecc657 100644 --- a/rasterize_points.cu +++ b/rasterize_points.cu @@ -1,14 +1,3 @@ -/* - * Copyright (C) 2023, Inria - * GRAPHDECO research group, https://team.inria.fr/graphdeco - * All rights reserved. - * - * This software is free for non-commercial, research and evaluation use - * under the terms of the LICENSE.md file. - * - * For inquiries contact george.drettakis@inria.fr - */ - #include #include #include @@ -33,30 +22,28 @@ std::function resizeFunctional(torch::Tensor& t) { } std::tuple -RasterizeGaussiansCUDA( - const torch::Tensor& background, - const torch::Tensor& means3D, - const torch::Tensor& colors, - const torch::Tensor& opacity, - const torch::Tensor& scales, - const torch::Tensor& rotations, - const float scale_modifier, - const torch::Tensor& cov3D_precomp, - const torch::Tensor& viewmatrix, - const torch::Tensor& projmatrix, - const float tan_fovx, - const float tan_fovy, - const int image_height, - const int image_width, - const torch::Tensor& sh, - const int degree, - const torch::Tensor& campos, - const bool prefiltered, - const bool antialiasing, - const bool debug) -{ +RasterizeGaussiansCUDA(const torch::Tensor& background, + const torch::Tensor& means3D, + const torch::Tensor& colors, + const torch::Tensor& opacity, + const torch::Tensor& scales, + const torch::Tensor& rotations, + const float scale_modifier, + const torch::Tensor& cov3D_precomp, + const torch::Tensor& viewmatrix, + const torch::Tensor& projmatrix, + const float tan_fovx, + const float tan_fovy, + const int image_height, + const int image_width, + const torch::Tensor& sh, + const int degree, + const torch::Tensor& campos, + const bool prefiltered, + const bool antialiasing, + const bool debug) { if (means3D.ndimension() != 2 || means3D.size(1) != 3) { - AT_ERROR("means3D must have dimensions (num_points, 3)"); + AT_ERROR("means3D must have dimensions (num_points, 3)"); } const int P = means3D.size(0); @@ -85,15 +72,13 @@ RasterizeGaussiansCUDA( std::function imgFunc = resizeFunctional(imgBuffer); int rendered = 0; - if(P != 0) - { - int M = 0; - if(sh.size(0) != 0) - { - M = sh.size(1); - } + if (P != 0) { + int M = 0; + if (sh.size(0) != 0) { + M = sh.size(1); + } - rendered = CudaRasterizer::Rasterizer::forward( + rendered = CudaRasterizer::Rasterizer::forward( geomFunc, binningFunc, imgFunc, @@ -124,102 +109,98 @@ RasterizeGaussiansCUDA( } std::tuple - RasterizeGaussiansBackwardCUDA( - const torch::Tensor& background, - const torch::Tensor& means3D, - const torch::Tensor& radii, - const torch::Tensor& colors, - const torch::Tensor& opacities, - const torch::Tensor& scales, - const torch::Tensor& rotations, - const float scale_modifier, - const torch::Tensor& cov3D_precomp, - const torch::Tensor& viewmatrix, - const torch::Tensor& projmatrix, - const float tan_fovx, - const float tan_fovy, - const torch::Tensor& dL_dout_color, - const torch::Tensor& dL_dout_invdepth, - const torch::Tensor& sh, - const int degree, - const torch::Tensor& campos, - const torch::Tensor& geomBuffer, - const int R, - const torch::Tensor& binningBuffer, - const torch::Tensor& imageBuffer, - const bool antialiasing, - const bool debug) -{ - const int P = means3D.size(0); - const int H = dL_dout_color.size(1); - const int W = dL_dout_color.size(2); +RasterizeGaussiansBackwardCUDA(const torch::Tensor& background, + const torch::Tensor& means3D, + const torch::Tensor& radii, + const torch::Tensor& colors, + const torch::Tensor& opacities, + const torch::Tensor& scales, + const torch::Tensor& rotations, + const float scale_modifier, + const torch::Tensor& cov3D_precomp, + const torch::Tensor& viewmatrix, + const torch::Tensor& projmatrix, + const float tan_fovx, + const float tan_fovy, + const torch::Tensor& dL_dout_color, + const torch::Tensor& dL_dout_invdepth, + const torch::Tensor& sh, + const int degree, + const torch::Tensor& campos, + const torch::Tensor& geomBuffer, + const int R, + const torch::Tensor& binningBuffer, + const torch::Tensor& imageBuffer, + const bool antialiasing, + const bool debug) { + const int P = means3D.size(0); + const int H = dL_dout_color.size(1); + const int W = dL_dout_color.size(2); - int M = 0; - if(sh.size(0) != 0) - { - M = sh.size(1); - } + int M = 0; + if (sh.size(0) != 0) { + M = sh.size(1); + } - torch::Tensor dL_dmeans3D = torch::zeros({P, 3}, means3D.options()); - torch::Tensor dL_dmeans2D = torch::zeros({P, 3}, means3D.options()); - torch::Tensor dL_dcolors = torch::zeros({P, NUM_CHANNELS}, means3D.options()); - torch::Tensor dL_dconic = torch::zeros({P, 2, 2}, means3D.options()); - torch::Tensor dL_dopacity = torch::zeros({P, 1}, means3D.options()); - torch::Tensor dL_dcov3D = torch::zeros({P, 6}, means3D.options()); - torch::Tensor dL_dsh = torch::zeros({P, M, 3}, means3D.options()); - torch::Tensor dL_dscales = torch::zeros({P, 3}, means3D.options()); - torch::Tensor dL_drotations = torch::zeros({P, 4}, means3D.options()); - torch::Tensor dL_dinvdepths = torch::zeros({0, 1}, means3D.options()); + torch::Tensor dL_dmeans3D = torch::zeros({P, 3}, means3D.options()); + torch::Tensor dL_dmeans2D = torch::zeros({P, 3}, means3D.options()); + torch::Tensor dL_dcolors = torch::zeros({P, NUM_CHANNELS}, means3D.options()); + torch::Tensor dL_dconic = torch::zeros({P, 2, 2}, means3D.options()); + torch::Tensor dL_dopacity = torch::zeros({P, 1}, means3D.options()); + torch::Tensor dL_dcov3D = torch::zeros({P, 6}, means3D.options()); + torch::Tensor dL_dsh = torch::zeros({P, M, 3}, means3D.options()); + torch::Tensor dL_dscales = torch::zeros({P, 3}, means3D.options()); + torch::Tensor dL_drotations = torch::zeros({P, 4}, means3D.options()); + torch::Tensor dL_dinvdepths = torch::zeros({0, 1}, means3D.options()); - float* dL_dinvdepthsptr = nullptr; - float* dL_dout_invdepthptr = nullptr; - if(dL_dout_invdepth.size(0) != 0) - { - dL_dinvdepths = torch::zeros({P, 1}, means3D.options()); - dL_dinvdepths = dL_dinvdepths.contiguous(); - dL_dinvdepthsptr = dL_dinvdepths.data(); - dL_dout_invdepthptr = dL_dout_invdepth.data(); - } + float* dL_dinvdepthsptr = nullptr; + float* dL_dout_invdepthptr = nullptr; + if (dL_dout_invdepth.size(0) != 0) { + dL_dinvdepths = torch::zeros({P, 1}, means3D.options()); + dL_dinvdepths = dL_dinvdepths.contiguous(); + dL_dinvdepthsptr = dL_dinvdepths.data(); + dL_dout_invdepthptr = dL_dout_invdepth.data(); + } - if(P != 0) - { - CudaRasterizer::Rasterizer::backward(P, degree, M, R, - background.contiguous().data(), - W, H, - means3D.contiguous().data(), - sh.contiguous().data(), - colors.contiguous().data(), - opacities.contiguous().data(), - scales.data_ptr(), - scale_modifier, - rotations.data_ptr(), - cov3D_precomp.contiguous().data(), - viewmatrix.contiguous().data(), - projmatrix.contiguous().data(), - campos.contiguous().data(), - tan_fovx, - tan_fovy, - radii.contiguous().data(), - reinterpret_cast(geomBuffer.contiguous().data_ptr()), - reinterpret_cast(binningBuffer.contiguous().data_ptr()), - reinterpret_cast(imageBuffer.contiguous().data_ptr()), - dL_dout_color.contiguous().data(), - dL_dout_invdepthptr, - dL_dmeans2D.contiguous().data(), - dL_dconic.contiguous().data(), - dL_dopacity.contiguous().data(), - dL_dcolors.contiguous().data(), - dL_dinvdepthsptr, - dL_dmeans3D.contiguous().data(), - dL_dcov3D.contiguous().data(), - dL_dsh.contiguous().data(), - dL_dscales.contiguous().data(), - dL_drotations.contiguous().data(), - antialiasing, - debug); - } + if (P != 0) { + CudaRasterizer::Rasterizer::backward(P, degree, M, R, + background.contiguous().data(), + W, + H, + means3D.contiguous().data(), + sh.contiguous().data(), + colors.contiguous().data(), + opacities.contiguous().data(), + scales.data_ptr(), + scale_modifier, + rotations.data_ptr(), + cov3D_precomp.contiguous().data(), + viewmatrix.contiguous().data(), + projmatrix.contiguous().data(), + campos.contiguous().data(), + tan_fovx, + tan_fovy, + radii.contiguous().data(), + reinterpret_cast(geomBuffer.contiguous().data_ptr()), + reinterpret_cast(binningBuffer.contiguous().data_ptr()), + reinterpret_cast(imageBuffer.contiguous().data_ptr()), + dL_dout_color.contiguous().data(), + dL_dout_invdepthptr, + dL_dmeans2D.contiguous().data(), + dL_dconic.contiguous().data(), + dL_dopacity.contiguous().data(), + dL_dcolors.contiguous().data(), + dL_dinvdepthsptr, + dL_dmeans3D.contiguous().data(), + dL_dcov3D.contiguous().data(), + dL_dsh.contiguous().data(), + dL_dscales.contiguous().data(), + dL_drotations.contiguous().data(), + antialiasing, + debug); + } - return std::make_tuple(dL_dmeans2D, dL_dcolors, dL_dopacity, dL_dmeans3D, dL_dcov3D, dL_dsh, dL_dscales, dL_drotations); + return std::make_tuple(dL_dmeans2D, dL_dcolors, dL_dopacity, dL_dmeans3D, dL_dcov3D, dL_dsh, dL_dscales, dL_drotations); } torch::Tensor markVisible( @@ -227,18 +208,18 @@ torch::Tensor markVisible( torch::Tensor& viewmatrix, torch::Tensor& projmatrix) { - const int P = means3D.size(0); - - torch::Tensor present = torch::full({P}, false, means3D.options().dtype(at::kBool)); + const int P = means3D.size(0); + + torch::Tensor present = torch::full({P}, false, means3D.options().dtype(at::kBool)); - if(P != 0) - { - CudaRasterizer::Rasterizer::markVisible(P, - means3D.contiguous().data(), - viewmatrix.contiguous().data(), - projmatrix.contiguous().data(), - present.contiguous().data()); - } + if (P != 0) { + CudaRasterizer::Rasterizer::markVisible( + P, + means3D.contiguous().data(), + viewmatrix.contiguous().data(), + projmatrix.contiguous().data(), + present.contiguous().data()); + } - return present; + return present; } diff --git a/rasterize_points.h b/rasterize_points.h index 82cbd4f2..bb8deb08 100644 --- a/rasterize_points.h +++ b/rasterize_points.h @@ -1,71 +1,62 @@ -/* - * Copyright (C) 2023, Inria - * GRAPHDECO research group, https://team.inria.fr/graphdeco - * All rights reserved. - * - * This software is free for non-commercial, research and evaluation use - * under the terms of the LICENSE.md file. - * - * For inquiries contact george.drettakis@inria.fr - */ - #pragma once +#ifndef RASTERIZER_POINTS_H +#define RASTERIZER_POINTS_H + #include #include #include #include std::tuple -RasterizeGaussiansCUDA( - const torch::Tensor& background, - const torch::Tensor& means3D, - const torch::Tensor& colors, - const torch::Tensor& opacity, - const torch::Tensor& scales, - const torch::Tensor& rotations, - const float scale_modifier, - const torch::Tensor& cov3D_precomp, - const torch::Tensor& viewmatrix, - const torch::Tensor& projmatrix, - const float tan_fovx, - const float tan_fovy, - const int image_height, - const int image_width, - const torch::Tensor& sh, - const int degree, - const torch::Tensor& campos, - const bool prefiltered, - const bool antialiasing, - const bool debug); +RasterizeGaussiansCUDA(const torch::Tensor& background, + const torch::Tensor& means3D, + const torch::Tensor& colors, + const torch::Tensor& opacity, + const torch::Tensor& scales, + const torch::Tensor& rotations, + const float scale_modifier, + const torch::Tensor& cov3D_precomp, + const torch::Tensor& viewmatrix, + const torch::Tensor& projmatrix, + const float tan_fovx, + const float tan_fovy, + const int image_height, + const int image_width, + const torch::Tensor& sh, + const int degree, + const torch::Tensor& campos, + const bool prefiltered, + const bool antialiasing, + const bool debug); std::tuple - RasterizeGaussiansBackwardCUDA( - const torch::Tensor& background, - const torch::Tensor& means3D, - const torch::Tensor& radii, - const torch::Tensor& colors, - const torch::Tensor& scales, - const torch::Tensor& opacities, - const torch::Tensor& rotations, - const float scale_modifier, - const torch::Tensor& cov3D_precomp, - const torch::Tensor& viewmatrix, - const torch::Tensor& projmatrix, - const float tan_fovx, - const float tan_fovy, - const torch::Tensor& dL_dout_color, - const torch::Tensor& dL_dout_invdepth, - const torch::Tensor& sh, - const int degree, - const torch::Tensor& campos, - const torch::Tensor& geomBuffer, - const int R, - const torch::Tensor& binningBuffer, - const torch::Tensor& imageBuffer, - const bool antialiasing, - const bool debug); + RasterizeGaussiansBackwardCUDA(const torch::Tensor& background, + const torch::Tensor& means3D, + const torch::Tensor& radii, + const torch::Tensor& colors, + const torch::Tensor& scales, + const torch::Tensor& opacities, + const torch::Tensor& rotations, + const float scale_modifier, + const torch::Tensor& cov3D_precomp, + const torch::Tensor& viewmatrix, + const torch::Tensor& projmatrix, + const float tan_fovx, + const float tan_fovy, + const torch::Tensor& dL_dout_color, + const torch::Tensor& dL_dout_invdepth, + const torch::Tensor& sh, + const int degree, + const torch::Tensor& campos, + const torch::Tensor& geomBuffer, + const int R, + const torch::Tensor& binningBuffer, + const torch::Tensor& imageBuffer, + const bool antialiasing, + const bool debug); -torch::Tensor markVisible( - torch::Tensor& means3D, - torch::Tensor& viewmatrix, - torch::Tensor& projmatrix); +torch::Tensor markVisible(torch::Tensor& means3D, + torch::Tensor& viewmatrix, + torch::Tensor& projmatrix); + +#endif // RASTERIZER_POINTS_H \ No newline at end of file