fix miscs
A-nnonymous committed Dec 15, 2025
commit a1b3234db36c83146d99407780203681c08fc4b1
27 changes: 13 additions & 14 deletions paddle/phi/kernels/impl/matmul_grad_kernel_impl.h
@@ -283,7 +283,6 @@ void MatmulGradKernel(const Context& dev_ctx,
is_x_been_broadcasted = x_ndim < y_ndim;
is_y_been_broadcasted = !is_x_been_broadcasted;
} else {
- #pragma unroll
for (int i = 0; i < ndim; i++) {
if (x_dims[i] != y_dims[i]) {
is_x_been_broadcasted = x_dims[i] < y_dims[i];
@@ -463,15 +462,15 @@ void MatmulGradKernel(const Context& dev_ctx,
phi::TransposeLast2Dim<T>(dev_ctx, out_grad);
DenseTensor y_conj_processed =
phi::TransposeLast2Dim<T>(dev_ctx, y_conj);
- size_t BN = 1;
+ int64_t BN = 1;
for (int i = 0; i < ndim - 1; i++) {
BN *= y_dims[i];
}
std::vector<std::int64_t> out_grad_2d_dim{BN, dout_dims[ndim - 2]};
std::vector<std::int64_t> y_conj_2d_dim{BN, y_dims[y_ndim - 2]};

- out_grad_processed->Resize(common::make_ddim(out_grad_2d_dim));
- y_conj_processed->Resize(common::make_ddim(y_conj_2d_dim));
+ out_grad_processed.Resize(common::make_ddim(out_grad_2d_dim));
+ y_conj_processed.Resize(common::make_ddim(y_conj_2d_dim));
// 2D x 2D -> 2D
MatMulFunction<Context, T>(dev_ctx,
out_grad_processed,
@@ -487,9 +486,9 @@ void MatmulGradKernel(const Context& dev_ctx,
for (int i = 0; i < ndim - 2; i++) {
x_grad_dim[i] = 1;
}
- x_grad_dims[ndim - 2] = dx_help.dims()[0];
- x_grad_dims[ndim - 1] = dx_help.dims()[1];
- dx_help->Resize(common::make_ddim(x_grad_dim));
+ x_grad_dim[ndim - 2] = dx_help.dims()[0];
+ x_grad_dim[ndim - 1] = dx_help.dims()[1];
+ dx_help.Resize(common::make_ddim(x_grad_dim));

} else {
MatMulFunction<Context, T>(dev_ctx,
@@ -507,7 +506,7 @@ void MatmulGradKernel(const Context& dev_ctx,
// Once y been broadcasted, we introduce a new aggregate dim
// original: [B, M, K] x [B, M, N] -> [B, K, N] -(reduceB)-> [K, N]
// new: [BM, K]' x [BM, N] -> [K, N]
- size_t BM = 1;
+ int64_t BM = 1;
for (int i = 0; i < ndim - 1; i++) {
BM *= x_dims[i];
}
@@ -516,8 +515,8 @@ void MatmulGradKernel(const Context& dev_ctx,

DenseTensor out_grad_processed = out_grad;
DenseTensor x_conj_processed = x_conj;
- out_grad_processed->Resize(common::make_ddim(out_grad_2d_dim));
- x_conj_processed->Resize(common::make_ddim(x_conj_2d_dim));
+ out_grad_processed.Resize(common::make_ddim(out_grad_2d_dim));
+ x_conj_processed.Resize(common::make_ddim(x_conj_2d_dim));

MatMulFunction<Context, T>(dev_ctx,
x_conj_processed,
@@ -528,13 +527,13 @@ void MatmulGradKernel(const Context& dev_ctx,
true,
false);
// make legacy reduce logic happy
std::vector<std::int64_t> x_grad_dim(ndim);
std::vector<std::int64_t> y_grad_dim(ndim);
for (int i = 0; i < ndim - 2; i++) {
y_grad_dim[i] = 1;
}
- y_grad_dims[ndim - 2] = dy_help.dims()[0];
- y_grad_dims[ndim - 1] = dy_help.dims()[1];
- dy_help->Resize(common::make_ddim(y_grad_dim));
+ y_grad_dim[ndim - 2] = dy_help.dims()[0];
+ y_grad_dim[ndim - 1] = dy_help.dims()[1];
+ dy_help.Resize(common::make_ddim(y_grad_dim));

} else {
MatMulFunction<Context, T>(dev_ctx,
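Note on the aggregate-dim trick used above: the diff's comment describes folding the reduce over the broadcast batch dim B into the matmul itself, by viewing [B, M, K] and [B, M, N] as 2D [BM, K] and [BM, N] tensors. Below is a minimal standalone sketch of that equivalence, assuming row-major contiguous layouts; it is not Paddle code, and the sizes and naive loop matmuls are purely illustrative.

```cpp
// Minimal sketch (not Paddle code): verifies that reducing the broadcast batch
// dim of x^T @ dout equals one 2D matmul over the flattened [B*M] rows.
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  const int B = 2, M = 3, K = 4, N = 5;  // illustrative sizes
  std::vector<double> x(static_cast<std::size_t>(B) * M * K);
  std::vector<double> dout(static_cast<std::size_t>(B) * M * N);
  for (std::size_t i = 0; i < x.size(); ++i) x[i] = 0.1 * static_cast<double>(i);
  for (std::size_t i = 0; i < dout.size(); ++i) dout[i] = 0.01 * static_cast<double>(i);

  // Reference path: [B, M, K]' x [B, M, N] -> [B, K, N] -(reduce B)-> [K, N]
  std::vector<double> dy_ref(K * N, 0.0);
  for (int b = 0; b < B; ++b)
    for (int k = 0; k < K; ++k)
      for (int n = 0; n < N; ++n)
        for (int m = 0; m < M; ++m)
          dy_ref[k * N + n] +=
              x[(b * M + m) * K + k] * dout[(b * M + m) * N + n];

  // Aggregate-dim path: reshape x to [BM, K] and dout to [BM, N] (same memory),
  // then a single [BM, K]' x [BM, N] -> [K, N] matmul, no explicit reduce.
  const int BM = B * M;
  std::vector<double> dy_2d(K * N, 0.0);
  for (int k = 0; k < K; ++k)
    for (int n = 0; n < N; ++n)
      for (int bm = 0; bm < BM; ++bm)
        dy_2d[k * N + n] += x[bm * K + k] * dout[bm * N + n];

  for (int i = 0; i < K * N; ++i)
    assert(std::abs(dy_ref[i] - dy_2d[i]) < 1e-9);
  std::printf("reshape trick matches batched reduce\n");
  return 0;
}
```

Collapsing B and M into one aggregate dimension is what the Resize calls in the diff set up: the gradient becomes a single 2D GEMM instead of a batched GEMM followed by a reduce over B.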