Commits (42)
3a8c699
naive implement
WintersMontagne10335 Dec 1, 2025
bcbeb54
update
WintersMontagne10335 Dec 1, 2025
1f92a46
Merge remote-tracking branch 'upstream/develop' into hackathon9th06planb
WintersMontagne10335 Dec 1, 2025
cb487f9
update
WintersMontagne10335 Dec 1, 2025
f582b9f
update
WintersMontagne10335 Dec 2, 2025
b63d45d
Merge remote-tracking branch 'upstream/develop' into hackathon9th06planb
WintersMontagne10335 Dec 2, 2025
d2cf355
update
WintersMontagne10335 Dec 2, 2025
5f77217
Merge remote-tracking branch 'upstream/develop' into hackathon9th06planb
WintersMontagne10335 Dec 2, 2025
5096c40
Merge remote-tracking branch 'upstream/develop' into hackathon9th06planb
WintersMontagne10335 Dec 5, 2025
ff24ce3
update
WintersMontagne10335 Dec 5, 2025
e15c147
Merge remote-tracking branch 'upstream/develop' into hackathon9th06planb
WintersMontagne10335 Dec 5, 2025
32c16f6
update
WintersMontagne10335 Dec 6, 2025
4fbc263
Merge remote-tracking branch 'upstream/develop' into hackathon9th06planb
WintersMontagne10335 Dec 6, 2025
1953589
update
WintersMontagne10335 Dec 6, 2025
9a2de36
update
WintersMontagne10335 Dec 6, 2025
f8e0db9
Merge remote-tracking branch 'upstream/develop' into hackathon9th06planb
WintersMontagne10335 Dec 8, 2025
1a41726
Merge remote-tracking branch 'upstream/develop' into hackathon9th06planb
WintersMontagne10335 Dec 9, 2025
4fe497c
Merge remote-tracking branch 'upstream/develop' into hackathon9th06planb
WintersMontagne10335 Dec 9, 2025
98e5b17
update
WintersMontagne10335 Dec 9, 2025
967b3bc
update
WintersMontagne10335 Dec 9, 2025
d54b205
Merge remote-tracking branch 'upstream/develop' into hackathon9th06planb
WintersMontagne10335 Dec 9, 2025
2133a77
update
WintersMontagne10335 Dec 9, 2025
37ea197
update
WintersMontagne10335 Dec 9, 2025
d694f46
update
WintersMontagne10335 Dec 10, 2025
8970a99
Merge remote-tracking branch 'upstream/develop' into hackathon9th06planb
WintersMontagne10335 Dec 10, 2025
351773b
update
WintersMontagne10335 Dec 10, 2025
5a6f478
Merge remote-tracking branch 'upstream/develop' into hackathon9th06planb
WintersMontagne10335 Dec 11, 2025
187670b
update
WintersMontagne10335 Dec 11, 2025
4dd58d8
Merge remote-tracking branch 'upstream/develop' into hackathon9th06planb
WintersMontagne10335 Dec 11, 2025
502d99f
Merge remote-tracking branch 'upstream/develop' into hackathon9th06planb
WintersMontagne10335 Dec 12, 2025
4ae3219
Merge remote-tracking branch 'upstream/develop' into hackathon9th06planb
WintersMontagne10335 Dec 12, 2025
b5e0d06
update for ci dcu
WintersMontagne10335 Dec 13, 2025
a2ad576
Merge remote-tracking branch 'upstream/develop' into hackathon9th06planb
WintersMontagne10335 Dec 13, 2025
1d818f5
update cmakelists
WintersMontagne10335 Dec 14, 2025
254ee30
Merge remote-tracking branch 'upstream/develop' into hackathon9th06planb
WintersMontagne10335 Dec 14, 2025
715e508
update
WintersMontagne10335 Dec 14, 2025
48601cd
Merge remote-tracking branch 'upstream/develop' into hackathon9th06planb
WintersMontagne10335 Dec 15, 2025
b2bd396
Sink the judgment logic down to the C++ layer
WintersMontagne10335 Dec 15, 2025
22a69c6
update
WintersMontagne10335 Dec 15, 2025
719e8a9
update
WintersMontagne10335 Dec 15, 2025
0898713
update op_build_gen
WintersMontagne10335 Dec 18, 2025
b51d436
Merge remote-tracking branch 'upstream/develop' into hackathon9th06planb
WintersMontagne10335 Dec 18, 2025
@@ -146,6 +146,12 @@ bool Pool2dGradOpInferSymbolicShape(
return true;
}

bool MaxPool2dWithDilationsGradOpInferSymbolicShape(
Review comment (Contributor):
These infer_sym functions need CINN mode enabled to be tested.

Reply (WintersMontagne10335, Dec 10, 2025):
Per https://github.com/PaddlePaddle/Paddle/blob/develop/test/legacy_test/op_test.py#L2941, testing this requires both check_symbol_infer (default True) and check_pir to be True; check_pir is already set to True.
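For reference, a minimal sketch of such a test (illustrative only: the attribute names follow the C++ code in this PR, while the input/output keys and the degenerate 1x1 configuration are assumptions, not code from this PR):

```python
import numpy as np
from op_test import OpTest  # test/legacy_test/op_test.py

class TestMaxPool2dWithDilationsSymbolInfer(OpTest):
    def setUp(self):
        self.op_type = "max_pool2d_with_dilations"
        # Degenerate 1x1 / stride-1 pooling: the output equals the input,
        # so no reference implementation is needed for this sketch.
        x = np.random.random((2, 3, 8, 8)).astype("float64")
        self.inputs = {"x": x}
        self.attrs = {
            "kernel_size": [1, 1],
            "strides": [1, 1],
            "paddings": [0, 0],
            "dilations": [1, 1],
            "ceil_mode": False,
            "data_format": "NCHW",
            "global_pooling": False,
            "padding_algorithm": "EXPLICIT",
        }
        self.outputs = {"out": x}

    def test_check_output(self):
        # check_symbol_infer defaults to True, so check_pir=True is the only
        # switch needed to exercise the InferSymbolicShape path.
        self.check_output(check_pir=True)
```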

pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
SameShapeInfer(infer_context, op->result(0), op->operand_source(0));
return true;
}

bool BceLossGradOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
SameShapeInfer(infer_context, op->result(0), op->operand_source(0));
@@ -24,6 +24,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Conv2dGrad)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MatmulGrad)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(DepthwiseConv2dGrad)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pool2dGrad)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaxPool2dWithDilationsGrad)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BceLossGrad)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BceLossGrad_)

@@ -216,6 +216,132 @@ symbol::ShapeOrDataDimExprs Pool2dRawInferSymbolicShape(

return output_shape_or_data;
}

symbol::ShapeOrDataDimExprs MaxPool2dWithDilationsRawInferSymbolicShape(
pir::Operation *op,
const std::vector<symbol::DimExpr> &kernel_size,
pir::InferSymbolicShapeContext *infer_context) {
const auto &x_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));

const auto &x_dims = x_shape_or_data.shape();
PADDLE_ENFORCE_EQ(
x_dims.size() == 4 || x_dims.size() == 5,
true,
common::errors::InvalidArgument(
"the input of Op(pool) should be 4-D or 5-D Tensor. But "
"received: %u-D Tensor.",
x_dims.size()));

PADDLE_ENFORCE_EQ(x_dims.size() - kernel_size.size(),
2U,
common::errors::InvalidArgument(
"the rank of input minus the size of kernel_size "
"must be equal to 2 in Op(pool). "
"But received: the rank of input is %d and the "
"rank of kernel_size is %d.",
x_dims.size(),
kernel_size.size()));

std::vector<int> strides = [&]() {
std::vector<int> res;
const auto &stride_attr =
op->attributes().at("strides").dyn_cast<pir::ArrayAttribute>();
for (size_t i = 0; i < stride_attr.size(); i++) {
res.emplace_back(
stride_attr.at(i).dyn_cast<pir::Int64Attribute>().data());
}
return res;
}();

PADDLE_ENFORCE_EQ(
kernel_size.size(),
strides.size(),
common::errors::InvalidArgument(
"the rank of kernel_size and strides in Op(pool) must be equal. "
"But received: the rank of kernel_size is %d and the rank of stride "
"is %d.",
kernel_size.size(),
strides.size()));

const std::string &data_format =
op->attribute<pir::StrAttribute>("data_format").AsString();
const bool channel_last = data_format == "NHWC" || data_format == "NDHWC";

const auto &data_dims = [&]() -> std::vector<symbol::DimExpr> {
if (channel_last) {
return std::vector<symbol::DimExpr>(x_dims.begin() + 1, x_dims.end() - 1);
} else {
return std::vector<symbol::DimExpr>(x_dims.begin() + 2, x_dims.end());
}
}();

bool global_pooling =
op->attribute<pir::BoolAttribute>("global_pooling").data();
std::string padding_algorithm =
op->attribute<pir::StrAttribute>("padding_algorithm").AsString();

const auto &real_paddings = [&]() -> std::vector<symbol::DimExpr> {
std::vector<int> paddings;
const auto &padding_attr =
op->attributes().at("paddings").dyn_cast<pir::ArrayAttribute>();
for (size_t i = 0; i < padding_attr.size(); i++) {
paddings.emplace_back(
padding_attr.at(i).dyn_cast<pir::Int64Attribute>().data());
}
    return GetRealPadding(paddings,
                          global_pooling,
                          false,
                          padding_algorithm,
                          data_dims,
                          strides,
                          kernel_size);
  }();

const auto &real_kernel_size = [&]() -> std::vector<symbol::DimExpr> {
if (global_pooling) {
return data_dims;
}
return kernel_size;
}();

const auto &output_shape_or_data = [&]() -> symbol::ShapeOrDataDimExprs {
std::vector<symbol::DimExpr> output_shape;
bool ceil_mode = op->attribute<pir::BoolAttribute>("ceil_mode").data();
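    // Standard pooling output-size arithmetic with integer (floor) division:
    //   floor mode: out = (in - k + pad_begin + pad_end) / stride + 1
    //   ceil mode:  out = (in - k + pad_begin + pad_end + stride - 1) / stride + 1
    // e.g. in = 6, k = 3, pad = 0, stride = 2 gives 2 (floor) or 3 (ceil).
    // Note: dilations are not folded into the kernel extent in this
    // computation.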
for (size_t i = 0; i < data_dims.size(); ++i) {
symbol::DimExpr stride_dimexpr{strides[i]};
symbol::DimExpr one_dimexpr{1};
if (!ceil_mode) {
output_shape.emplace_back((data_dims[i] - real_kernel_size[i] +
real_paddings[2 * i] +
real_paddings[2 * i + 1]) /
stride_dimexpr +
one_dimexpr);
} else {
output_shape.emplace_back(
(data_dims[i] - real_kernel_size[i] + real_paddings[2 * i] +
real_paddings[2 * i + 1] + stride_dimexpr - one_dimexpr) /
stride_dimexpr +
one_dimexpr);
}
}

// output_N = input_N
output_shape.insert(output_shape.begin(), x_dims[0]);
// output_C = input_C
if (channel_last) {
output_shape.push_back(x_dims[x_dims.size() - 1]);
} else {
output_shape.insert(output_shape.begin() + 1, x_dims[1]);
}
return symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(output_shape)};
}();

return output_shape_or_data;
}
} // namespace

namespace paddle::dialect {
@@ -2947,6 +3073,19 @@ bool Pool2dOpInferSymbolicShape(pir::Operation *op,
return true;
}

bool MaxPool2dWithDilationsOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
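  // kernel_size arrives as a 1-D tensor operand (operand 1); its element
  // values are read back as DimExprs for the shape computation.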
const auto &kernel_size_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(1));
const auto &kernel_size =
paddle::dialect::details::GetExprVecFromData(kernel_size_shape_or_data);
infer_context->SetShapeOrDataForValue(
op->result(0),
MaxPool2dWithDilationsRawInferSymbolicShape(
op, kernel_size, infer_context));
return true;
}

bool Pool3dOpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
std::vector<int64_t> kernel_size_ =
@@ -125,6 +125,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(PartialSum)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pad)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pad3d)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pool2d)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaxPool2dWithDilations)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pool3d)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pool)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Prod)
14 changes: 14 additions & 0 deletions paddle/phi/infermeta/backward.cc
@@ -1134,6 +1134,20 @@ void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
dx->share_meta(x);
}

void MaxPool2dWithIndexGradInferMeta(const MetaTensor& x,
const MetaTensor& mask,
const MetaTensor& dout,
const std::vector<int>& kernel_size,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
bool global_pooling,
bool adaptive,
bool ceil_mode,
MetaTensor* dx) {
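  // dx mirrors x's meta (shape, dtype, layout); no extra shape arithmetic is
  // required for the gradient.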
dx->share_meta(x);
}

void MedianGradInferMeta(const MetaTensor& x,
const MetaTensor& median_data,
const MetaTensor& median_index,
13 changes: 13 additions & 0 deletions paddle/phi/infermeta/backward.h
@@ -454,6 +454,19 @@ PADDLE_API void MaxPoolWithIndexGradInferMeta(
bool ceil_mode,
MetaTensor* dx);

PADDLE_API void MaxPool2dWithIndexGradInferMeta(
const MetaTensor& x,
const MetaTensor& mask,
const MetaTensor& dout,
const std::vector<int>& kernel_size,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
bool global_pooling,
bool adaptive,
bool ceil_mode,
MetaTensor* dx);

PADDLE_API void MedianGradInferMeta(const MetaTensor& x,
const MetaTensor& median_data,
const MetaTensor& median_index,
49 changes: 49 additions & 0 deletions paddle/phi/infermeta/unary.cc
@@ -2770,6 +2770,29 @@ void MaxOutInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}

void MaxPool2dWithIndexInferMeta(const MetaTensor& x,
const std::vector<int>& kernel_size,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
bool global_pooling,
bool adaptive,
bool ceil_mode,
MetaTensor* out,
MetaTensor* mask,
MetaConfig config) {
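  // Forwards to MaxPoolWithIndexInferMeta, which takes no dilations
  // parameter, so dilations do not affect the inferred shapes here.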
MaxPoolWithIndexInferMeta(x,
kernel_size,
strides,
paddings,
global_pooling,
adaptive,
ceil_mode,
out,
mask,
config);
}

void MaxPoolWithIndexInferMeta(const MetaTensor& x,
const std::vector<int>& kernel_size,
const std::vector<int>& strides,
@@ -3794,6 +3817,32 @@ void Pool2DInferMeta(const MetaTensor& x,
}
}

void MaxPool2DWithDilationsInferMeta(const MetaTensor& x,
const IntArray& kernel_size,
const std::vector<int64_t>& strides,
const std::vector<int64_t>& paddings,
const std::vector<int64_t>& dilations,
bool ceil_mode,
const std::string& data_format,
bool global_pooling,
const std::string& padding_algorithm,
MetaTensor* out,
MetaConfig config) {
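  // A minimal wrapper: reuse Pool2DInferMeta with pooling_type = "max".
  // Going by Pool2DInferMeta's parameter order, the two `false` literals
  // below are `exclusive` and `adaptive`; dilations do not enter the meta
  // computation.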
Pool2DInferMeta(x,
kernel_size,
strides,
paddings,
ceil_mode,
false,
data_format,
"max",
global_pooling,
false,
padding_algorithm,
out,
config);
}

void PSendInferMeta(const MetaTensor& x, int peer) {
LOG(INFO) << "SendBaseInferMeta begin";
PADDLE_ENFORCE_GE(
25 changes: 25 additions & 0 deletions paddle/phi/infermeta/unary.h
@@ -482,6 +482,18 @@ PADDLE_API void MaxPoolWithIndexInferMeta(const MetaTensor& x,
MetaTensor* mask,
MetaConfig config = MetaConfig());

PADDLE_API void MaxPool2dWithIndexInferMeta(const MetaTensor& x,
const std::vector<int>& kernel_size,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
bool global_pooling,
bool adaptive,
bool ceil_mode,
MetaTensor* out,
MetaTensor* mask,
MetaConfig config = MetaConfig());

PADDLE_API void MaxPoolV2InferMeta(const MetaTensor& x,
const std::vector<int>& kernel_size,
const std::vector<int>& strides,
@@ -625,6 +637,19 @@ PADDLE_API void Pool2DInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());

PADDLE_API void MaxPool2DWithDilationsInferMeta(
const MetaTensor& x,
const IntArray& kernel_size,
const std::vector<int64_t>& strides,
const std::vector<int64_t>& paddings,
const std::vector<int64_t>& dilations,
bool ceil_mode,
const std::string& data_format,
bool global_pooling,
const std::string& padding_algorithm,
MetaTensor* out,
MetaConfig config = MetaConfig());

PADDLE_API void PSendInferMeta(const MetaTensor& x, int peer);

PADDLE_API void PSendArrayInferMeta(const MetaTensor& x, int peer);
7 changes: 7 additions & 0 deletions paddle/phi/kernels/cpu/pool_grad_kernel.cc
@@ -24,6 +24,13 @@ PD_REGISTER_KERNEL(pool2d_grad,
float,
double,
phi::float16) {}
PD_REGISTER_KERNEL(max_pool2d_with_dilations_grad,
CPU,
ALL_LAYOUT,
phi::MaxPool2DWithDilationsGradKernel,
float,
double,
phi::float16) {}
PD_REGISTER_KERNEL(
lp_pool2d_grad, CPU, ALL_LAYOUT, phi::LPPool2dGradKernel, float, double) {}
PD_REGISTER_KERNEL(pool2d_double_grad,
7 changes: 7 additions & 0 deletions paddle/phi/kernels/cpu/pool_kernel.cc
@@ -19,6 +19,13 @@

PD_REGISTER_KERNEL(
pool2d, CPU, ALL_LAYOUT, phi::Pool2dKernel, float, double, phi::float16) {}
PD_REGISTER_KERNEL(max_pool2d_with_dilations,
CPU,
ALL_LAYOUT,
phi::MaxPool2DWithDilationsKernel,
float,
double,
phi::float16) {}
PD_REGISTER_KERNEL(
lp_pool2d, CPU, ALL_LAYOUT, phi::LPPool2dKernel, float, double) {}
PD_REGISTER_KERNEL(max_pool2d_with_index,