diff --git a/falkon/csrc/cuda/lauum.cu b/falkon/csrc/cuda/lauum.cu
index a7dcb5d9..326d6cb3 100644
--- a/falkon/csrc/cuda/lauum.cu
+++ b/falkon/csrc/cuda/lauum.cu
@@ -10,7 +10,7 @@
 #define BLOCK_SIZE 32
 //#define DEBUG
-
+/*
 template <typename scalar_t>
 __global__
 void upper_cuda_lauum_ker(const scalar_t* __restrict__ in,
@@ -75,6 +75,166 @@ void upper_cuda_lauum_ker(const scalar_t* __restrict__ in,
         out[row + col * out_stride] = accumulator;
     }
 }
+*/
+
+
+/*
+ * Definitions for lauum_upper_ker_tri_tiled
+ */
+#define BLK_N 96
+#define BLK_K 16
+#define DIM_READ_X 16
+#define DIM_READ_Y DIM_READ_X
+#define DIM_COMP_X 16
+#define DIM_COMP_Y DIM_COMP_X
+#define THR_N ( BLK_N / DIM_COMP_X )
+
+
+/*
+ * Triangular, tiled implementation with double-buffered registers
+ * and thread coarsening (each thread computes multiple output
+ * elements). Quite heavily inspired by GEMM in MAGMA.
+ */
+template <typename scalar_t>
+__global__
+void lauum_upper_ker_tri_tiled_adv(const scalar_t* __restrict__ in,
+                                   scalar_t* __restrict__ out,
+                                   const int size,
+                                   const int in_stride,
+                                   const int out_stride,
+                                   const int grid_size)
+{
+    const int2 p = tri_index_lower(blockIdx.x);  // lower and upper are mixed up.
+    const int tx = threadIdx.x;
+    const int ty = threadIdx.y;
+
+    // DIM_COMP_X, DIM_COMP_Y: size of the thread block for computing output.
+    // DIM_READ_X, DIM_READ_Y: size of the thread blocks for reading A, B.
+    // BLK_K, BLK_N: the multiplication is between two matrices of shape N, K.
+    // The first dimension (N) is also referred to as X, the second (K) as Y.
+    __shared__ scalar_t sA[BLK_K][BLK_N];
+    __shared__ scalar_t sB[BLK_K][BLK_N + 1];
+
+    scalar_t rC[THR_N * THR_N];       // 36
+    scalar_t rA[THR_N];               // 6
+    scalar_t rB[THR_N];               // 6
+
+    scalar_t ra[BLK_N / DIM_READ_X];  // 6
+    scalar_t rb[BLK_N / DIM_READ_X];  // 6
+
+    // The total work (output size) of the thread block is BLK_N * BLK_N, but
+    // there are only DIM_COMP_X * DIM_COMP_Y threads, so each thread works on
+    // more than a single output element.
+    // The thread-ids are indices of the current thread within the BLK_N, BLK_N
+    // work block. Note that ty goes horizontally and tx vertically.
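+    // Worked example of the coarsening arithmetic (illustration only): with
+    // BLK_N = 96 and DIM_COMP_X = DIM_COMP_Y = 16, THR_N = 96 / 16 = 6, so
+    // each of the 256 threads owns a 6 x 6 grid of accumulators (rC) whose
+    // entries are spaced DIM_COMP_X apart within the 96 x 96 output tile.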
+    const int tid_global = DIM_COMP_X * ty + tx;
+
+    const int tid_x = tid_global % DIM_READ_X;
+    const int tid_y = tid_global / DIM_READ_X;
+
+    int i, j, k, ki;
+    int jj;
+    int col;
+    const int row_a = p.x * BLK_N + tid_x;
+    const int row_b = p.y * BLK_N + tid_x;
+
+    // Zero-out rC
+    # pragma unroll
+    for (i = 0; i < THR_N * THR_N; i++) {
+        rC[i] = 0;
+    }
+
+    // Global -> shared (sA, sB). Reads are clamped into bounds with min();
+    // out-of-range products are masked out again when writing the output.
+    col = p.y * BLK_N + tid_y;
+    # pragma unroll
+    for (i = 0; i < BLK_K; i += DIM_READ_Y) {
+        # pragma unroll
+        for (j = 0; j < BLK_N; j += DIM_READ_X) {
+            if (row_a + j <= col) {
+                sA[tid_y + i][tid_x + j] = in[min(row_a + j + col * in_stride, size * in_stride - 1)];
+            } else {
+                sA[tid_y + i][tid_x + j] = 0;
+            }
+            if (row_b + j <= col) {
+                sB[tid_y + i][tid_x + j] = in[min(row_b + j + col * in_stride, size * in_stride - 1)];
+            } else {
+                sB[tid_y + i][tid_x + j] = 0;
+            }
+        }
+        col += DIM_READ_Y;
+    }
+    __syncthreads();
+
+    for (k = p.y * BLK_N + BLK_K; k < size; k += BLK_K) {
+        // Load global -> registers (the buffer for the next iteration)
+        col = k + tid_y;
+        # pragma unroll
+        for (j = 0, jj = 0; jj < BLK_N; j++, jj += DIM_READ_X) {
+            if (row_a + jj <= col) {
+                ra[j] = in[min(row_a + jj + col * in_stride, size * in_stride - 1)];
+            } else {
+                ra[j] = 0;
+            }
+            if (row_b + jj <= col) {
+                rb[j] = in[min(row_b + jj + col * in_stride, size * in_stride - 1)];
+            } else {
+                rb[j] = 0;
+            }
+        }
+        // Multiply
+        # pragma unroll
+        for (ki = 0; ki < BLK_K; ki++) {
+            // Shared -> registers
+            # pragma unroll
+            for (i = 0; i < THR_N; i++) {
+                rA[i] = sA[ki][i * DIM_COMP_X + tx];
+                rB[i] = sB[ki][i * DIM_COMP_Y + ty];
+            }
+            // Compute
+            # pragma unroll
+            for (i = 0; i < THR_N * THR_N; i++) {
+                rC[i] += rA[i / THR_N] * rB[i % THR_N];
+            }
+        }
+        __syncthreads();
+        // Load registers -> shared
+        # pragma unroll
+        for (j = 0, jj = 0; j < BLK_N; j += DIM_READ_X, jj++) {
+            sA[tid_y][tid_x + j] = ra[jj];
+            sB[tid_y][tid_x + j] = rb[jj];
+        }
+        __syncthreads();
+    }
+    // Multiply the last (possibly partial) block: k has overshot size, so
+    // only the first size - (k - BLK_K) columns of the buffers are valid.
+    # pragma unroll
+    for (ki = 0; ki < BLK_K; ki++) {
+        if (ki >= size - k + BLK_K)
+            break;
+        // Shared -> registers
+        # pragma unroll
+        for (i = 0; i < THR_N; i++) {
+            rA[i] = sA[ki][i * DIM_COMP_X + tx];
+            rB[i] = sB[ki][i * DIM_COMP_Y + ty];
+        }
+        // Compute
+        # pragma unroll
+        for (i = 0; i < THR_N * THR_N; i++) {
+            rC[i] += rA[i / THR_N] * rB[i % THR_N];
+        }
+    }
+
+    // Write the accumulated results, keeping only the upper triangle.
+    col = p.y * BLK_N + tid_y;
+    # pragma unroll
+    for (i = 0; i < THR_N * THR_N; i++) {
+        const int out_row = row_a + (i / THR_N) * DIM_COMP_X;
+        const int out_col = col + (i % THR_N) * DIM_COMP_Y;
+        if (out_row <= out_col && out_col < size) {
+            out[out_row + out_col * out_stride] = rC[i];
+        }
+    }
+}
+
+
 template <typename scalar_t>
@@ -148,22 +308,25 @@ torch::Tensor lauum_cuda(const int n, const torch::Tensor &A, const int lda, torch::Tensor &B, const int ldb, const bool lower) {
     const auto in_stride = lda;
     const auto out_stride = ldb;
 
-    // Setup CUDA grid dimensions:
-    // grid is 1D, so that we can only consider triangularly-appropriate tiles
-    // blocks are 2D, with a fixed block size
-    const int grid_height = ceildiv(size, BLOCK_SIZE);
-    const dim3 dimGrid(grid_height * (grid_height + 1) / 2, 1);
-    const dim3 dimBlock(BLOCK_SIZE, BLOCK_SIZE);
 
     AT_DISPATCH_FLOATING_TYPES(scalar_type, "dispatch_lauum_cuda", [&] {
        at::DeviceGuard g(A.device());
        at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
        if (lower) {
+           // Setup CUDA grid dimensions:
+           // grid is 1D, so that we only consider triangularly-appropriate tiles;
+           // blocks are 2D, with a fixed block size.
+           const int grid_height = ceildiv(size, BLOCK_SIZE);
+           const dim3 dimGrid(grid_height * (grid_height + 1) / 2, 1);
+           const dim3 dimBlock(BLOCK_SIZE, BLOCK_SIZE);
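+           // For example, with size = 1024 this gives grid_height = 32 and
+           // dimGrid.x = 32 * 33 / 2 = 528 blocks, roughly half the 1024
+           // blocks a full square grid would launch for the same matrix.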
            lower_cuda_lauum_ker<scalar_t><<<dimGrid, dimBlock, 0, stream.stream()>>>(
                A.data_ptr<scalar_t>(), B.data_ptr<scalar_t>(), size, in_stride, out_stride, grid_height);
        } else {
-           upper_cuda_lauum_ker<scalar_t><<<dimGrid, dimBlock, 0, stream.stream()>>>(
+           const int grid_height = ceildiv(size, BLK_N);
+           const dim3 dimGrid(grid_height * (grid_height + 1) / 2, 1);
+           const dim3 dimBlock(DIM_COMP_X, DIM_COMP_Y);
+           lauum_upper_ker_tri_tiled_adv<scalar_t><<<dimGrid, dimBlock, 0, stream.stream()>>>(
                A.data_ptr<scalar_t>(), B.data_ptr<scalar_t>(), size, in_stride, out_stride, grid_height);
        }
    });
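
A note on tri_index_lower: the helper is defined elsewhere in falkon/csrc/cuda and is not part of this diff. The sketch below shows one standard way to implement such a mapping, inverting the triangular number to turn a linear block index into 2-D tile coordinates; the name tri_index_lower_sketch and the details are illustrative assumptions, not the actual falkon code.

__device__ int2 tri_index_lower_sketch(const int linear_index) {
    // Row r of a lower-triangular tile grid holds r + 1 tiles and is preceded
    // by r * (r + 1) / 2 tiles; inverting that triangular number gives the row.
    int row = (int)((sqrt(8.0 * (double)linear_index + 1.0) - 1.0) * 0.5);
    // Guard against floating-point rounding at the row boundaries.
    if (row * (row + 1) / 2 > linear_index) row--;
    if ((row + 1) * (row + 2) / 2 <= linear_index) row++;
    return make_int2(row, linear_index - row * (row + 1) / 2);
}

Under such a layout, a 1-D launch of grid_height * (grid_height + 1) / 2 blocks, exactly the dimGrid used by both branches above, enumerates one block per triangular tile.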