Update the Python part
WintersMontagne10335 committed Nov 23, 2025
commit 0c6a3c908a6f87dc731ce3940bb9c072317b6b4c
66 changes: 45 additions & 21 deletions python/paddle/nn/functional/pooling.py
@@ -264,6 +264,8 @@ def avg_pool1d(
stride = convert_to_list(stride, 1, 'pool_stride')
stride = [1, *stride]

dilation = convert_to_list(1, 2, 'pool_dilation')

_check_value_limitation(kernel_size, "kernel_size", min_limit=1e-3)
_check_value_limitation(stride, "stride", min_limit=1e-3)
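Throughout this diff the default dilation is normalized with `convert_to_list(1, 2, 'pool_dilation')`. A rough sketch of the behavior that call is assumed to have (the real helper in Paddle's utils may do additional validation; this is not a copy of it):

```python
# Hedged sketch of the list-normalization relied on in this diff; the actual
# Paddle helper may differ in details.
def convert_to_list(value, n, name):
    """Replicate a scalar into an n-element list, or validate an n-element sequence."""
    if isinstance(value, (list, tuple)):
        if len(value) != n:
            raise ValueError(f"{name} must contain {n} elements, got {len(value)}")
        return list(value)
    return [value] * n

assert convert_to_list(1, 2, 'pool_dilation') == [1, 1]       # default used in this PR
assert convert_to_list((2, 3), 2, 'pool_dilation') == [2, 3]  # user-specified dilation
```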

@@ -281,6 +283,7 @@ def avg_pool1d(
kernel_size,
stride,
padding,
dilation,
ceil_mode,
exclusive,
data_format,
@@ -307,6 +310,7 @@ def avg_pool1d(
"global_pooling": False,
"strides": stride,
"paddings": padding,
"dilations": dilation,
"padding_algorithm": padding_algorithm,
"use_cudnn": True,
"ceil_mode": ceil_mode,
@@ -395,12 +399,15 @@ def avg_pool2d(
padding, 2, channel_last, ceil_mode=ceil_mode
)

dilation = convert_to_list(1, 2, 'pool_dilation')

if in_dynamic_or_pir_mode():
output = _C_ops.pool2d(
x,
kernel_size,
stride,
padding,
dilation,
ceil_mode,
exclusive,
data_format,
@@ -433,6 +440,7 @@ def avg_pool2d(
"global_pooling": False,
"strides": stride,
"paddings": padding,
"dilations": dilation,
"padding_algorithm": padding_algorithm,
"use_cudnn": True,
"ceil_mode": ceil_mode,
@@ -645,13 +653,15 @@ def max_pool1d(
padding, 1, ceil_mode=ceil_mode
)

dilation = convert_to_list(1, 2, 'pool_dilation')

# use 2d to implement 1d should expand padding in advance.
padding = _expand_low_nd_padding(padding)

if in_dynamic_or_pir_mode():
if return_mask:
pool_out = _C_ops.max_pool2d_with_index(
x, kernel_size, stride, padding, False, False, ceil_mode
x, kernel_size, stride, padding, dilation, False, False, ceil_mode
)
return (
(squeeze(pool_out[0], [2]), squeeze(pool_out[1], [2]))
@@ -664,6 +674,7 @@ def max_pool1d(
kernel_size,
stride,
padding,
dilation,
ceil_mode,
True,
data_format,
@@ -693,6 +704,7 @@ def max_pool1d(
"global_pooling": False,
"strides": stride,
"paddings": padding,
"dilations": dilation,
"padding_algorithm": padding_algorithm,
"use_cudnn": True,
"ceil_mode": ceil_mode,
@@ -1141,22 +1153,22 @@ def max_pool2d(
kernel_size: Size2,
stride: Size2 | None = None,
padding: _PaddingSizeMode | Size2 | Size4 = 0,
dilation: Size2 = 1,
return_mask: bool = False,
ceil_mode: bool = False,
data_format: DataLayout2D = 'NCHW',
dilation: Size2 = 1,  # newly added dilation parameter, defaults to 1
name: str | None = None,
) -> Tensor:
"""
This API implements max pooling 2d operation with an optional dilation parameter.
This API implements max pooling 2d operation.
See more details in :ref:`api_paddle_nn_MaxPool2d` .

Args:
x (Tensor): The input tensor of pooling operator which is a 4-D tensor with
shape [N, C, H, W]. The format of input tensor is `"NCHW"` or
`"NHWC"`, where `N` is batch size, `C` is the number of channels,
`H` is the height of the feature, and `W` is the width of the
feature. The data type is float32 or float64.
shape [N, C, H, W]. The format of input tensor is `"NCHW"` or
`"NHWC"`, where `N` is batch size, `C` is the number of channels,
`H` is the height of the feature, and `W` is the width of the
feature. The data type is float32 or float64.
kernel_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list,
it must contain two integers, (kernel_size_Height, kernel_size_Width).
Otherwise, the pool kernel size will be a square of an int.
@@ -1166,10 +1178,10 @@ def max_pool2d(
padding (string|int|list|tuple): The padding size. Padding could be in one of the following forms.
1. A string in ['valid', 'same'].
2. An int, which means the feature map is zero padded by size of `padding` on every sides.
3. A list[int] or tuple(int) whose length is 2, [pad_height, pad_width] whose value means the padding size of each dimension.
3. A list[int] or tuple(int) whose length is 2, [pad_height, pad_width] whose value means the padding size of each dimension.
4. A list[int] or tuple(int) whose length is 4. [pad_height_top, pad_height_bottom, pad_width_left, pad_width_right] whose value means the padding size of each side.
5. A list or tuple of pairs of integers. It has the form [[pad_before, pad_after], [pad_before, pad_after], ...]. Note that, the batch dimension and channel dimension should be [0,0] or (0,0).
dilation (int|list|tuple): The dilation rate of the pooling kernel. Default is 1.
The default value is 0.
ceil_mode (bool): when True, will use `ceil` instead of `floor` to compute the output shape
return_mask (bool): Whether to return the max indices along with the outputs. Default False, only support `"NCHW"` data format
data_format (string): The data format of the input and output data. An optional string from: `"NCHW"`, `"NHWC"`.
@@ -1178,7 +1190,6 @@ def max_pool2d(
name(str|None, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and
None by default.

Returns:
Tensor: The output tensor of pooling result. The data type is same as input tensor.

@@ -1188,9 +1199,9 @@ def max_pool2d(
>>> import paddle
>>> import paddle.nn.functional as F

>>> # max pool2d with dilation
>>> # max pool2d
>>> x = paddle.uniform([1, 3, 32, 32], paddle.float32)
>>> out = F.max_pool2d(x, kernel_size=2, stride=2, padding=0, dilation=2)
>>> out = F.max_pool2d(x, kernel_size=2, stride=2, padding=0)
>>> print(out.shape)
[1, 3, 16, 16]
>>> # for return_mask=True
@@ -1200,8 +1211,7 @@ def max_pool2d(
>>> print(max_indices.shape)
[1, 3, 16, 16]
"""
# Convert dilation to a list/tuple for easier downstream handling
dilation = convert_to_list(dilation, 2, 'pool_dilation')

kernel_size = convert_to_list(kernel_size, 2, 'pool_size')
if stride is None:
stride = kernel_size
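For reference, a small sketch of the output-size arithmetic implied by a dilation argument, assuming the usual effective-kernel convention k_eff = dilation * (kernel - 1) + 1 (an assumption on my part; the PR itself does not spell out the formula):

```python
# Sketch of dilated-pooling output-size arithmetic (assumed convention, not taken
# from the PR): a dilated kernel spans dilation*(kernel-1)+1 input positions.
import math

def pooled_dim(size, kernel, stride, pad, dilation=1, ceil_mode=False):
    k_eff = dilation * (kernel - 1) + 1            # effective kernel extent
    span = size + 2 * pad - k_eff
    return (math.ceil(span / stride) if ceil_mode else span // stride) + 1

# Using the docstring's example input [1, 3, 32, 32] with kernel_size=2, stride=2, padding=0:
print(pooled_dim(32, kernel=2, stride=2, pad=0, dilation=1))  # 16 -> [1, 3, 16, 16]
print(pooled_dim(32, kernel=2, stride=2, pad=0, dilation=2))  # 15 (effective kernel is 3)
```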
@@ -1220,16 +1230,17 @@ def max_pool2d(
padding, num_dims=2, channel_last=channel_last, ceil_mode=ceil_mode
)

dilation = convert_to_list(dilation, 2, 'pool_dilation')

if data_format == "NHWC" and return_mask:
raise ValueError(
"When setting return_mask to true, data_format must be set to NCHW in API:max_pool2d"
)

# Call the underlying max_pool2d operation
if in_dynamic_or_pir_mode():
if return_mask:
output = _C_ops.max_pool2d_with_index(
x, kernel_size, stride, padding, False, False, ceil_mode, dilation
x, kernel_size, stride, padding, dilation, False, False, ceil_mode
)
return output if return_mask else output[0]
else:
@@ -1238,14 +1249,14 @@ def max_pool2d(
kernel_size,
stride,
padding,
dilation,
ceil_mode,
True,
data_format,
'max',
False,
False,
padding_algorithm,
dilation,  # newly added dilation
)

else:
@@ -1271,12 +1282,12 @@ def max_pool2d(
"global_pooling": False,
"strides": stride,
"paddings": padding,
"dilations": dilation,
"padding_algorithm": padding_algorithm,
"use_cudnn": True,
"ceil_mode": ceil_mode,
"exclusive": True,
"data_format": data_format,
"dilations": dilation,  # newly added dilations
},
)
return (pool_out, mask)
@@ -1294,12 +1305,12 @@ def max_pool2d(
"global_pooling": False,
"strides": stride,
"paddings": padding,
"dilations": dilation,
"padding_algorithm": padding_algorithm,
"use_cudnn": True,
"ceil_mode": ceil_mode,
"exclusive": True,
"data_format": data_format,
"dilations": dilation,  # newly added dilations
},
)
return pool_out
@@ -1492,6 +1503,7 @@ def adaptive_avg_pool1d(
pool_type = 'avg'
_check_input(x, 3)
pool_size = [1, *convert_to_list(output_size, 1, "pool_size")]
dilation = convert_to_list(1, 2, 'pool_dilation')

x = unsqueeze(x, [2])
if in_dynamic_or_pir_mode():
@@ -1502,6 +1514,7 @@ def adaptive_avg_pool1d(
pool_size,
[1, 1],
[0, 0],
dilation,
False,
True,
"NCHW",
@@ -1529,6 +1542,7 @@ def adaptive_avg_pool1d(
attrs={
"pooling_type": pool_type,
"ksize": pool_size,
"dilations": dilation,
"adaptive": True,
},
)
@@ -1631,6 +1645,8 @@ def adaptive_avg_pool2d(
elif _contain_var(output_size):
output_size = _convert_to_tensor_list(output_size)

dilation = convert_to_list(1, 2, 'pool_dilation')

if in_dynamic_or_pir_mode():
if in_dynamic_mode():
x = x._use_gpudnn(False)
@@ -1639,6 +1655,7 @@ def adaptive_avg_pool2d(
output_size,
[1, 1],
[0, 0],
dilation,
False,
True,
data_format,
@@ -1666,6 +1683,7 @@ def adaptive_avg_pool2d(
attrs={
"pooling_type": "avg",
"ksize": output_size,
"dilations": dilation,
"adaptive": True,
"data_format": data_format,
},
@@ -1869,9 +1887,11 @@ def adaptive_max_pool1d(
pool_size = [1, *convert_to_list(output_size, 1, "pool_size")]

x = unsqueeze(x, [2])

dilation = convert_to_list(1, 2, 'pool_dilation')
if in_dynamic_or_pir_mode():
pool_out = _C_ops.max_pool2d_with_index(
x, pool_size, [1, 1], [0, 0], False, True, False
x, pool_size, [1, 1], [0, 0], dilation, False, True, False
)
return (
(squeeze(pool_out[0], [2]), squeeze(pool_out[1], [2]))
@@ -1901,6 +1921,7 @@ def adaptive_max_pool1d(
attrs={
"pooling_type": 'max',
"ksize": pool_size,
"dilations": dilation,
"adaptive": True,
"ceil_mode": False,
},
@@ -1970,9 +1991,11 @@ def adaptive_max_pool2d(
output_size[0] = in_h
if output_size[1] is None:
output_size[1] = in_w

dilation = convert_to_list(1, 2, 'pool_dilation')
if in_dynamic_or_pir_mode():
pool_out = _C_ops.max_pool2d_with_index(
x, output_size, [1, 1], [0, 0], False, True, False
x, output_size, [1, 1], [0, 0], dilation, False, True, False
)
return pool_out if return_mask else pool_out[0]
else:
@@ -1998,6 +2021,7 @@ def adaptive_max_pool2d(
attrs={
"pooling_type": 'max',
"ksize": output_size,
"dilations": dilation,
"adaptive": True,
"ceil_mode": False,
},
6 changes: 5 additions & 1 deletion python/paddle/nn/layer/pooling.py
@@ -785,6 +785,7 @@ class MaxPool2D(Layer):
kernel_size: Size2
stride: Size2 | None
padding: _PaddingSizeMode | Size2 | Size4
dilation: Size2
Contributor

No unit tests have been added. The API unit tests and the OP unit tests need to thoroughly cover the forward and backward computation.

Dynamic graph, PIR, and dynamic-to-static CINN modes all need to be tested.

Contributor Author

Got it, will do.
return_mask: bool
ceil_mode: bool
data_format: DataLayout2D
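Following up on the review request above for unit tests, here is a hypothetical minimal test sketch (not part of this PR); it assumes the updated `F.max_pool2d` signature from this diff, and the real coverage would go through Paddle's OpTest-based suites and also exercise PIR and CINN:

```python
# Hypothetical test sketch for the new dilation argument; assumes the updated
# F.max_pool2d signature introduced in this diff.
import numpy as np
import paddle
import paddle.nn.functional as F

def pool_fn(x):
    return F.max_pool2d(x, kernel_size=3, stride=1, padding=0, dilation=2)

def test_dilation_dygraph_vs_dy2static():
    np.random.seed(2024)
    x = paddle.to_tensor(np.random.rand(2, 3, 16, 16).astype('float32'))
    eager_out = pool_fn(x)                      # dynamic-graph path
    static_fn = paddle.jit.to_static(pool_fn)   # dynamic-to-static path
    static_out = static_fn(x)
    np.testing.assert_allclose(eager_out.numpy(), static_out.numpy(), rtol=1e-6)

test_dilation_dygraph_vs_dy2static()
```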
@@ -795,6 +796,7 @@ def __init__(
kernel_size: Size2,
stride: Size2 | None = None,
padding: _PaddingSizeMode | Size2 | Size4 = 0,
dilation: Size2 = 1,
return_mask: bool = False,
ceil_mode: bool = False,
data_format: DataLayout2D = 'NCHW',
@@ -804,6 +806,7 @@ def __init__(
self.ksize = kernel_size
self.stride = stride
self.padding = padding
self.dilation = dilation
self.return_mask = return_mask
self.ceil_mode = ceil_mode
self.data_format = data_format
@@ -815,14 +818,15 @@ def forward(self, x: Tensor) -> Tensor:
kernel_size=self.ksize,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
return_mask=self.return_mask,
ceil_mode=self.ceil_mode,
data_format=self.data_format,
name=self.name,
)

def extra_repr(self) -> str:
return 'kernel_size={ksize}, stride={stride}, padding={padding}'.format(
return 'kernel_size={ksize}, stride={stride}, padding={padding}, dilation={dilation}'.format(
**self.__dict__
)
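A hedged usage sketch for the updated layer, assuming the `dilation` argument and the extended `extra_repr` shown in this diff land as written (the shape comment follows the effective-kernel convention assumed earlier):

```python
# Usage sketch (assumes the MaxPool2D changes in this diff are merged).
import paddle
import paddle.nn as nn

pool = nn.MaxPool2D(kernel_size=3, stride=1, padding=0, dilation=2)
x = paddle.uniform([1, 3, 32, 32], dtype='float32')
y = pool(x)
print(pool)      # extra_repr should now report dilation as well
print(y.shape)   # expected [1, 3, 28, 28]: 32 - (2*(3-1)+1) + 1 = 28
```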
