217 changes: 185 additions & 32 deletions paddle/phi/kernels/impl/matmul_grad_kernel_impl.h
@@ -31,6 +31,7 @@ limitations under the License. */
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/phi/kernels/gpu/reduce.h"
#endif
COMMON_DECLARE_bool(use_legacy_gemm);
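// When FLAGS_use_legacy_gemm is set, the gradient paths below fall back to
// the original batched-GEMM + reduce logic instead of the folded 2-D GEMMs.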

namespace phi {

@@ -98,6 +99,25 @@ static DenseTensor FoldHeadAndLastDims(const Context& dev_ctx,
return output;
}

#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) && !defined(_WIN32)
// Reshape a rank-3 tensor from [B, M, N] to [(B * N), M] so the gradient
// can be computed as a single [M, BN] x [BN, K] -> [M, K] GEMM, saving the
// reduction over B and avoiding the [1, 0, 2] permute for better performance.
// (Warning: this requires transposing the data and writes into new memory.)
// Identity op if the tensor is not of rank 3.
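// Example: a [2, 3, 4] input is transposed on its last two dims to
// [2, 4, 3] and then reshaped to [8, 3], i.e. [(B * N), M].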
template <typename Context, typename T>
static DenseTensor FoldBatchIntoAggregation(const Context& dev_ctx,
const DenseTensor& input) {
auto in_dims = input.dims();
if (in_dims.size() != 3) {
return input;
}
DenseTensor output = phi::TransposeLast2Dim<T>(dev_ctx, input);
output.Resize({in_dims[0] * in_dims[2], in_dims[1]});
return output;
}
#endif

template <typename Context, typename T>
typename std::enable_if<!std::is_integral<T>::value>::type MatMul(
const Context& dev_ctx,
Expand Down Expand Up @@ -204,20 +224,43 @@ void CalcInputGrad(const Context& dev_ctx,
if (out == nullptr) return;
bool need_combine =
(a.dims().size() == 3 || b.dims().size() == 3) && out->dims().size() == 2;
if (!need_combine) {
MatMul<Context, T>(dev_ctx, a, trans_a, b, trans_b, out, flag);
} else {
MatMul<Context, T>(
dev_ctx,
is_fold_init_dims_a ? FoldInitDims(a)
: FoldHeadAndLastDims<Context, T>(dev_ctx, a),
trans_a,
is_fold_init_dims_b ? FoldInitDims(b)
: FoldHeadAndLastDims<Context, T>(dev_ctx, b),
trans_b,
out,
flag);

DenseTensor a_processed = a, b_processed = b;
bool trans_a_processed = trans_a, trans_b_processed = trans_b;
if (need_combine) {
#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) && !defined(_WIN32)
if (!FLAGS_use_legacy_gemm) {
a_processed = is_fold_init_dims_a
? FoldInitDims(a)
: FoldBatchIntoAggregation<Context, T>(dev_ctx, a);
b_processed = is_fold_init_dims_b
? FoldInitDims(b)
: FoldBatchIntoAggregation<Context, T>(dev_ctx, b);
// Folding the batch dimension into the aggregation dimension swaps the
// last two dims, so the transpose flag has to be flipped.
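// e.g. a [B, M, K] operand becomes [B * K, M]; reading it as transposed
// restores the [M, B * K] view that the legacy FoldHeadAndLastDims path
// would have produced, without the [1, 0, 2] permute.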
trans_a_processed = is_fold_init_dims_a ? trans_a : !trans_a;
trans_b_processed = is_fold_init_dims_b ? trans_b : !trans_b;
} else // NOLINT
#endif
{ // NOLINT
a_processed = is_fold_init_dims_a
? FoldInitDims(a)
: FoldHeadAndLastDims<Context, T>(dev_ctx, a);
b_processed = is_fold_init_dims_b
? FoldInitDims(b)
: FoldHeadAndLastDims<Context, T>(dev_ctx, b);
}
}
std::vector<std::int64_t> a_dims = common::vectorize(a_processed.dims());
std::vector<std::int64_t> b_dims = common::vectorize(b_processed.dims());
MatMulFunction<Context, T>(dev_ctx,
a_processed,
b_processed,
a_dims,
b_dims,
out,
trans_a_processed,
trans_b_processed);
}

template <typename T, typename Context>
Expand Down Expand Up @@ -264,7 +307,7 @@ void MatmulGradKernel(const Context& dev_ctx,
}

bool is_broadcast = true;
if (x_ndim <= 2 || y_ndim <= 2) {
if (y_ndim <= 2 || x_ndim <= 2) {
is_broadcast = false;
} else if (x_ndim != y_ndim) {
is_broadcast = true;
@@ -273,6 +316,29 @@ void MatmulGradKernel(const Context& dev_ctx,
x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, y_dims.cbegin());
}

bool is_y_been_broadcasted = false;
bool is_x_been_broadcasted = false;
// NOTE(Pan Zhaowu): Figure out which tensor has been broadcast, so that
// the batch dim of the other tensor can be folded into the aggregation
// dim, avoiding a batched GEMM and saving the reduction cost.
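// e.g. x: [1, M, K] against y: [B, K, N] with B > 1 marks x as the
// broadcast operand, so the dx path below can fold B into the aggregation
// dim instead of reducing over it afterwards.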
if (is_broadcast) {
if (x_ndim != y_ndim) {
is_x_been_broadcasted = x_ndim < y_ndim;
is_y_been_broadcasted = !is_x_been_broadcasted;
} else {
int64_t x_batch = 1;
int64_t y_batch = 1;
for (int i = 0; i < x_ndim - 2; ++i) {
x_batch *= x_dims[i];
}
for (int i = 0; i < y_ndim - 2; ++i) {
y_batch *= y_dims[i];
}
is_x_been_broadcasted = x_batch < y_batch;
is_y_been_broadcasted = !is_x_been_broadcasted;
}
}

// for complex
DenseTensor x_conj;
DenseTensor y_conj;
@@ -432,24 +498,111 @@ void MatmulGradKernel(const Context& dev_ctx,
false);
} else {
// XY: dX = GY', dY = X'G
if (dx)
MatMulFunction<Context, T>(dev_ctx,
out_grad,
y_conj,
dout_dims,
y_dims,
&dx_help,
false,
true);
if (dy)
MatMulFunction<Context, T>(dev_ctx,
x_conj,
out_grad,
x_dims,
dout_dims,
&dy_help,
true,
false);
VLOG(3)
<< "matmul grad case: transpose_x = false && transpose_y = false";
if (dx) {
#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) && !defined(_WIN32)
if (!FLAGS_use_legacy_gemm && is_x_been_broadcasted && x_ndim == 3 &&
ndim == 3) {
// Once x has been broadcast, fold the batch dim into a new aggregation dim
// original: [B, M, N] x [B, K, N]' -> [B, M, K] -(reduce B)-> [M, K]
// new: [BN, M]' x [BN, K] -> [M, K]
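// i.e. dX[m, k] = sum over (b, n) of out_grad[b, m, n] * y_conj[b, k, n],
// evaluated as one 2-D GEMM instead of a batched GEMM plus a reduce over B.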
DenseTensor out_grad_processed =
phi::TransposeLast2Dim<T>(dev_ctx, out_grad);
DenseTensor y_conj_processed =
phi::TransposeLast2Dim<T>(dev_ctx, y_conj);
int64_t BN = 1;
std::vector<std::int64_t> y_processed_dims =
common::vectorize(y_conj_processed.dims());
for (int i = 0; i < ndim - 1; i++) {
BN *= y_processed_dims[i];
}
std::vector<std::int64_t> out_grad_2d_dim{BN, dout_dims[ndim - 2]};
std::vector<std::int64_t> y_conj_2d_dim{BN, y_dims[y_ndim - 2]};
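// out_grad_2d_dim is [B * N, M] and y_conj_2d_dim is [B * N, K].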

out_grad_processed.Resize(common::make_ddim(out_grad_2d_dim));
y_conj_processed.Resize(common::make_ddim(y_conj_2d_dim));
// 2D x 2D -> 2D
MatMulFunction<Context, T>(dev_ctx,
out_grad_processed,
y_conj_processed,
out_grad_2d_dim,
y_conj_2d_dim,
&dx_help,
true,
false);

// make legacy reduce logic happy
std::vector<std::int64_t> x_grad_dim(ndim);
for (int i = 0; i < ndim - 2; i++) {
x_grad_dim[i] = 1;
}
x_grad_dim[ndim - 2] = dx_help.dims()[0];
x_grad_dim[ndim - 1] = dx_help.dims()[1];
dx_help.Resize(common::make_ddim(x_grad_dim));

} else // NOLINT
#endif
{ // NOLINT
MatMulFunction<Context, T>(dev_ctx,
out_grad,
y_conj,
dout_dims,
y_dims,
&dx_help,
false,
true);
} // if is_x_been_broadcasted
} // if dx
if (dy) {
#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) && !defined(_WIN32)
if (!FLAGS_use_legacy_gemm && is_y_been_broadcasted && y_ndim == 3 &&
ndim == 3) {
// Once y has been broadcast, fold the batch dim into a new aggregation dim
// original: [B, M, K]' x [B, M, N] -> [B, K, N] -(reduce B)-> [K, N]
// new: [BM, K]' x [BM, N] -> [K, N]
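// i.e. dY[k, n] = sum over (b, m) of x_conj[b, m, k] * out_grad[b, m, n],
// evaluated as one 2-D GEMM instead of a batched GEMM plus a reduce over B.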
int64_t BM = 1;
for (int i = 0; i < ndim - 1; i++) {
BM *= x_dims[i];
}
std::vector<std::int64_t> out_grad_2d_dim{BM, dout_dims[ndim - 1]};
std::vector<std::int64_t> x_conj_2d_dim{BM, x_dims[x_ndim - 1]};
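// out_grad_2d_dim is [B * M, N] and x_conj_2d_dim is [B * M, K].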

DenseTensor out_grad_processed = out_grad;
DenseTensor x_conj_processed = x_conj;
out_grad_processed.Resize(common::make_ddim(out_grad_2d_dim));
x_conj_processed.Resize(common::make_ddim(x_conj_2d_dim));

MatMulFunction<Context, T>(dev_ctx,
x_conj_processed,
out_grad_processed,
x_conj_2d_dim,
out_grad_2d_dim,
&dy_help,
true,
false);
// make legacy reduce logic happy
std::vector<std::int64_t> y_grad_dim(ndim);
for (int i = 0; i < ndim - 2; i++) {
y_grad_dim[i] = 1;
}
y_grad_dim[ndim - 2] = dy_help.dims()[0];
y_grad_dim[ndim - 1] = dy_help.dims()[1];
dy_help.Resize(common::make_ddim(y_grad_dim));

} else // NOLINT
#endif
{ // NOLINT
MatMulFunction<Context, T>(dev_ctx,
x_conj,
out_grad,
x_dims,
dout_dims,
&dy_help,
true,
false);
} // if is_y_been_broadcasted
} // if dy
}
}
