diff --git a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py index 60840cc60ec5e9..abd8bd4b050ec8 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py @@ -135,6 +135,7 @@ 'KthvalueInferMeta', 'MaxPoolWithIndexInferMeta', 'MaxPoolV2InferMeta', + 'MaxPool2DWithDilationsInferMeta', 'MinMaxWithIndexInferMeta', 'MultinomialInferMeta', 'OverlapAddInferMeta', diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/backward_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/backward_infer_sym.cc index f7d7d1101e82b8..12eb06643b8d0a 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/backward_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/backward_infer_sym.cc @@ -146,6 +146,12 @@ bool Pool2dGradOpInferSymbolicShape( return true; } +bool MaxPool2dWithDilationsGradOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + SameShapeInfer(infer_context, op->result(0), op->operand_source(0)); + return true; +} + bool BceLossGradOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { SameShapeInfer(infer_context, op->result(0), op->operand_source(0)); diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/backward_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/backward_infer_sym.h index 510654c8c51ab2..14ac24d95a2134 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/backward_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/backward_infer_sym.h @@ -24,6 +24,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Conv2dGrad) OP_DECLARE_INFER_SYMBOLIC_SHAPE(MatmulGrad) OP_DECLARE_INFER_SYMBOLIC_SHAPE(DepthwiseConv2dGrad) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pool2dGrad) 
+OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaxPool2dWithDilationsGrad) OP_DECLARE_INFER_SYMBOLIC_SHAPE(BceLossGrad) OP_DECLARE_INFER_SYMBOLIC_SHAPE(BceLossGrad_) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 3b24ec9e458e40..9446107bfe1008 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -216,6 +216,132 @@ symbol::ShapeOrDataDimExprs Pool2dRawInferSymbolicShape( return output_shape_or_data; } + +symbol::ShapeOrDataDimExprs MaxPool2dWithDilationsRawInferSymbolicShape( + pir::Operation *op, + const std::vector &kernel_size, + pir::InferSymbolicShapeContext *infer_context) { + const auto &x_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + + const auto &x_dims = x_shape_or_data.shape(); + PADDLE_ENFORCE_EQ( + x_dims.size() == 4 || x_dims.size() == 5, + true, + common::errors::InvalidArgument( + "the input of Op(pool) should be 4-D or 5-D Tensor. But " + "received: %u-D Tensor.", + x_dims.size())); + + PADDLE_ENFORCE_EQ(x_dims.size() - kernel_size.size(), + 2U, + common::errors::InvalidArgument( + "the rank of input minus the size of kernel_size " + "must be equal to 2 in Op(pool). " + "But received: the rank of input is %d and the " + "rank of kernel_size is %d.", + x_dims.size(), + kernel_size.size())); + + std::vector strides = [&]() { + std::vector res; + const auto &stride_attr = + op->attributes().at("strides").dyn_cast(); + for (size_t i = 0; i < stride_attr.size(); i++) { + res.emplace_back( + stride_attr.at(i).dyn_cast().data()); + } + return res; + }(); + + PADDLE_ENFORCE_EQ( + kernel_size.size(), + strides.size(), + common::errors::InvalidArgument( + "the rank of kernel_size and strides in Op(pool) must be equal. 
" + "But received: the rank of kernel_size is %d and the rank of stride " + "is %d.", + kernel_size.size(), + strides.size())); + + const std::string &data_format = + op->attribute("data_format").AsString(); + const bool channel_last = data_format == "NHWC" || data_format == "NDHWC"; + + const auto &data_dims = [&]() -> std::vector { + if (channel_last) { + return std::vector(x_dims.begin() + 1, x_dims.end() - 1); + } else { + return std::vector(x_dims.begin() + 2, x_dims.end()); + } + }(); + + bool global_pooling = + op->attribute("global_pooling").data(); + std::string padding_algorithm = + op->attribute("padding_algorithm").AsString(); + + const auto &real_paddings = [&]() -> std::vector { + std::vector paddings; + const auto &padding_attr = + op->attributes().at("paddings").dyn_cast(); + for (size_t i = 0; i < padding_attr.size(); i++) { + paddings.emplace_back( + padding_attr.at(i).dyn_cast().data()); + } + return GetRealPadding(paddings, + global_pooling, + false, + padding_algorithm, + data_dims, + strides, + kernel_size + + ); + }(); + + const auto &real_kernel_size = [&]() -> std::vector { + if (global_pooling) { + return data_dims; + } + return kernel_size; + }(); + + const auto &output_shape_or_data = [&]() -> symbol::ShapeOrDataDimExprs { + std::vector output_shape; + bool ceil_mode = op->attribute("ceil_mode").data(); + for (size_t i = 0; i < data_dims.size(); ++i) { + symbol::DimExpr stride_dimexpr{strides[i]}; + symbol::DimExpr one_dimexpr{1}; + if (!ceil_mode) { + output_shape.emplace_back((data_dims[i] - real_kernel_size[i] + + real_paddings[2 * i] + + real_paddings[2 * i + 1]) / + stride_dimexpr + + one_dimexpr); + } else { + output_shape.emplace_back( + (data_dims[i] - real_kernel_size[i] + real_paddings[2 * i] + + real_paddings[2 * i + 1] + stride_dimexpr - one_dimexpr) / + stride_dimexpr + + one_dimexpr); + } + } + + // output_N = input_N + output_shape.insert(output_shape.begin(), x_dims[0]); + // output_C = input_C + if (channel_last) { + 
output_shape.push_back(x_dims[x_dims.size() - 1]); + } else { + output_shape.insert(output_shape.begin() + 1, x_dims[1]); + } + return symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(output_shape)}; + }(); + + return output_shape_or_data; +} } // namespace namespace paddle::dialect { @@ -2947,6 +3073,19 @@ bool Pool2dOpInferSymbolicShape(pir::Operation *op, return true; } +bool MaxPool2dWithDilationsOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + const auto &kernel_size_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(1)); + const auto &kernel_size = + paddle::dialect::details::GetExprVecFromData(kernel_size_shape_or_data); + infer_context->SetShapeOrDataForValue( + op->result(0), + MaxPool2dWithDilationsRawInferSymbolicShape( + op, kernel_size, infer_context)); + return true; +} + bool Pool3dOpInferSymbolicShape(pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { std::vector kernel_size_ = diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h index a6b97105bcdc02..a2c0e05664763a 100755 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h @@ -125,6 +125,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(PartialSum) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pad) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pad3d) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pool2d) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaxPool2dWithDilations) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pool3d) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pool) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Prod) diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 85f45eee29b432..c05966dd4ca561 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -1134,6 +1134,20 @@ 
void MaxPoolWithIndexGradInferMeta(const MetaTensor& x, dx->share_meta(x); } +void MaxPool2dWithIndexGradInferMeta(const MetaTensor& x, + const MetaTensor& mask, + const MetaTensor& dout, + const std::vector& kernel_size, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + bool global_pooling, + bool adaptive, + bool ceil_mode, + MetaTensor* dx) { + dx->share_meta(x); +} + void MedianGradInferMeta(const MetaTensor& x, const MetaTensor& median_data, const MetaTensor& median_index, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 9c330fee9be532..d8fd84bd7488f7 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -454,6 +454,19 @@ PADDLE_API void MaxPoolWithIndexGradInferMeta( bool ceil_mode, MetaTensor* dx); +PADDLE_API void MaxPool2dWithIndexGradInferMeta( + const MetaTensor& x, + const MetaTensor& mask, + const MetaTensor& dout, + const std::vector& kernel_size, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + bool global_pooling, + bool adaptive, + bool ceil_mode, + MetaTensor* dx); + PADDLE_API void MedianGradInferMeta(const MetaTensor& x, const MetaTensor& median_data, const MetaTensor& median_index, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index ac171094731d2f..03a4e17779c222 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -2770,6 +2770,29 @@ void MaxOutInferMeta(const MetaTensor& x, out->set_dtype(x.dtype()); } +void MaxPool2dWithIndexInferMeta(const MetaTensor& x, + const std::vector& kernel_size, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + bool global_pooling, + bool adaptive, + bool ceil_mode, + MetaTensor* out, + MetaTensor* mask, + MetaConfig config) { + MaxPoolWithIndexInferMeta(x, + kernel_size, + strides, + paddings, + global_pooling, + adaptive, + ceil_mode, + out, + mask, + 
config); +} + void MaxPoolWithIndexInferMeta(const MetaTensor& x, const std::vector& kernel_size, const std::vector& strides, @@ -3794,6 +3817,32 @@ void Pool2DInferMeta(const MetaTensor& x, } } +void MaxPool2DWithDilationsInferMeta(const MetaTensor& x, + const IntArray& kernel_size, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + bool ceil_mode, + const std::string& data_format, + bool global_pooling, + const std::string& padding_algorithm, + MetaTensor* out, + MetaConfig config) { + Pool2DInferMeta(x, + kernel_size, + strides, + paddings, + ceil_mode, + true, + data_format, + "max", + global_pooling, + false, + padding_algorithm, + out, + config); +} + void PSendInferMeta(const MetaTensor& x, int peer) { LOG(INFO) << "SendBaseInferMeta begin"; PADDLE_ENFORCE_GE( diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 4e50607263950b..d5183f6fa10ba9 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -482,6 +482,18 @@ PADDLE_API void MaxPoolWithIndexInferMeta(const MetaTensor& x, MetaTensor* mask, MetaConfig config = MetaConfig()); +PADDLE_API void MaxPool2dWithIndexInferMeta(const MetaTensor& x, + const std::vector& kernel_size, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + bool global_pooling, + bool adaptive, + bool ceil_mode, + MetaTensor* out, + MetaTensor* mask, + MetaConfig config = MetaConfig()); + PADDLE_API void MaxPoolV2InferMeta(const MetaTensor& x, const std::vector& kernel_size, const std::vector& strides, @@ -625,6 +637,19 @@ PADDLE_API void Pool2DInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config = MetaConfig()); +PADDLE_API void MaxPool2DWithDilationsInferMeta( + const MetaTensor& x, + const IntArray& kernel_size, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + bool ceil_mode, + const std::string& data_format, + bool global_pooling, + const 
std::string& padding_algorithm, + MetaTensor* out, + MetaConfig config = MetaConfig()); + PADDLE_API void PSendInferMeta(const MetaTensor& x, int peer); PADDLE_API void PSendArrayInferMeta(const MetaTensor& x, int peer); diff --git a/paddle/phi/kernels/cpu/pool_grad_kernel.cc b/paddle/phi/kernels/cpu/pool_grad_kernel.cc index 17b1c63a95bb70..539c455cdaf6eb 100644 --- a/paddle/phi/kernels/cpu/pool_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/pool_grad_kernel.cc @@ -24,6 +24,13 @@ PD_REGISTER_KERNEL(pool2d_grad, float, double, phi::float16) {} +PD_REGISTER_KERNEL(max_pool2d_with_dilations_grad, + CPU, + ALL_LAYOUT, + phi::MaxPool2DWithDilationsGradKernel, + float, + double, + phi::float16) {} PD_REGISTER_KERNEL( lp_pool2d_grad, CPU, ALL_LAYOUT, phi::LPPool2dGradKernel, float, double) {} PD_REGISTER_KERNEL(pool2d_double_grad, diff --git a/paddle/phi/kernels/cpu/pool_kernel.cc b/paddle/phi/kernels/cpu/pool_kernel.cc index 02a867f70060ad..77fc08179a431f 100644 --- a/paddle/phi/kernels/cpu/pool_kernel.cc +++ b/paddle/phi/kernels/cpu/pool_kernel.cc @@ -19,6 +19,13 @@ PD_REGISTER_KERNEL( pool2d, CPU, ALL_LAYOUT, phi::Pool2dKernel, float, double, phi::float16) {} +PD_REGISTER_KERNEL(max_pool2d_with_dilations, + CPU, + ALL_LAYOUT, + phi::MaxPool2DWithDilationsKernel, + float, + double, + phi::float16) {} PD_REGISTER_KERNEL( lp_pool2d, CPU, ALL_LAYOUT, phi::LPPool2dKernel, float, double) {} PD_REGISTER_KERNEL(max_pool2d_with_index, diff --git a/paddle/phi/kernels/funcs/pooling.cc b/paddle/phi/kernels/funcs/pooling.cc index 1ba7e69d568d3f..36028ece87b08a 100644 --- a/paddle/phi/kernels/funcs/pooling.cc +++ b/paddle/phi/kernels/funcs/pooling.cc @@ -172,6 +172,315 @@ class Pool2dFunctor { } }; +template +class MaxPool2DWithDilationsFunctor { + public: + void operator()(const CPUContext& context, + const DenseTensor& input, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + const std::string data_format, + 
DenseTensor* output) { + bool channel_last = (data_format == "NHWC"); + + const int64_t batch_size = input.dims()[0]; + const int64_t input_channels = + channel_last ? input.dims()[3] : input.dims()[1]; + const int64_t input_height = + channel_last ? input.dims()[1] : input.dims()[2]; + const int64_t input_width = + channel_last ? input.dims()[2] : input.dims()[3]; + + const int64_t output_channels = + channel_last ? output->dims()[3] : output->dims()[1]; + const int64_t output_height = + channel_last ? output->dims()[1] : output->dims()[2]; + const int64_t output_width = + channel_last ? output->dims()[2] : output->dims()[3]; + + const int64_t ksize_height = ksize[0]; + const int64_t ksize_width = ksize[1]; + + const int64_t stride_height = strides[0]; + const int64_t stride_width = strides[1]; + + const int64_t padding_height = paddings[0]; + const int64_t padding_width = paddings[1]; + + const int64_t dilation_height = dilations[0]; + const int64_t dilation_width = dilations[1]; + + const T* input_data = input.data(); + T* output_data = context.template Alloc(output); + + int64_t hstart = 0, hend = 1; + int64_t wstart = 0, wend = 1; + if (!channel_last) { + const int64_t input_stride = input_height * input_width; + const int64_t output_stride = output_height * output_width; + for (int64_t i = 0; i < batch_size; i++) { + for (int64_t c = 0; c < output_channels; ++c) { + for (int64_t ph = 0; ph < output_height; ++ph) { + for (int64_t pw = 0; pw < output_width; ++pw) { + hstart = ph * stride_height - padding_height; + hend = hstart + (ksize_height - 1) * dilation_height + 1; + + hstart = (hstart < static_cast(0)) + ? hstart + ((-hstart + dilation_height - 1) / + dilation_height) * + dilation_height + : hstart; + + hend = + (hend > input_height) + ? input_height - ((hend - input_height) % dilation_height) + : hend; + + wstart = pw * stride_width - padding_width; + wend = wstart + (ksize_width - 1) * dilation_width + 1; + + wstart = (wstart < static_cast(0)) + ? 
wstart + ((-wstart + dilation_width - 1) / + dilation_width) * + dilation_width + : wstart; + + wend = (wend > input_width) + ? input_width - ((wend - input_width) % dilation_width) + : wend; + + T ele = static_cast(-FLT_MAX); + for (int64_t h = hstart; h < hend; h += dilation_height) { + for (int64_t w = wstart; w < wend; w += dilation_width) { + ele = input_data[h * input_width + w] > ele + ? input_data[h * input_width + w] + : ele; + } + } + output_data[ph * output_width + pw] = ele; + } + } + input_data += input_stride; + output_data += output_stride; + } + } + } else { + const int64_t input_stride = input_height * input_width * input_channels; + const int64_t output_stride = + output_height * output_width * output_channels; + for (int64_t i = 0; i < batch_size; i++) { + for (int64_t c = 0; c < output_channels; ++c) { + for (int64_t ph = 0; ph < output_height; ++ph) { + for (int64_t pw = 0; pw < output_width; ++pw) { + hstart = ph * stride_height - padding_height; + hend = hstart + (ksize_height - 1) * dilation_height + 1; + + hstart = (hstart < static_cast(0)) + ? hstart + ((-hstart + dilation_height - 1) / + dilation_height) * + dilation_height + : hstart; + + hend = + (hend > input_height) + ? input_height - ((hend - input_height) % dilation_height) + : hend; + + wstart = pw * stride_width - padding_width; + wend = wstart + (ksize_width - 1) * dilation_width + 1; + + wstart = (wstart < static_cast(0)) + ? wstart + ((-wstart + dilation_width - 1) / + dilation_width) * + dilation_width + : wstart; + + wend = (wend > input_width) + ? input_width - ((wend - input_width) % dilation_width) + : wend; + T ele = static_cast(-FLT_MAX); + for (int64_t h = hstart; h < hend; h += dilation_height) { + for (int64_t w = wstart; w < wend; w += dilation_width) { + ele = input_data[h * input_width * input_channels + + w * input_channels + c] > ele + ? 
input_data[h * input_width * input_channels + + w * input_channels + c] + : ele; + } + } + output_data[ph * output_width * output_channels + + pw * output_channels + c] = ele; + } + } + } + input_data += input_stride; + output_data += output_stride; + } + } + } +}; + +template +class MaxPool2DWithDilationsGradFunctor { + public: + void operator()(const CPUContext& context, + const DenseTensor& input, + const DenseTensor& output, + const DenseTensor& output_grad, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + const std::string data_format, + DenseTensor* input_grad) { + bool channel_last = (data_format == "NHWC"); + + const int64_t batch_size = input.dims()[0]; + + const int64_t input_channels = + channel_last ? input.dims()[3] : input.dims()[1]; + const int64_t input_height = + channel_last ? input.dims()[1] : input.dims()[2]; + const int64_t input_width = + channel_last ? input.dims()[2] : input.dims()[3]; + + const int64_t output_channels = + channel_last ? output.dims()[3] : output.dims()[1]; + const int64_t output_height = + channel_last ? output.dims()[1] : output.dims()[2]; + const int64_t output_width = + channel_last ? 
output.dims()[2] : output.dims()[3]; + + const int64_t ksize_height = ksize[0]; + const int64_t ksize_width = ksize[1]; + + const int64_t stride_height = strides[0]; + const int64_t stride_width = strides[1]; + + const int64_t padding_height = paddings[0]; + const int64_t padding_width = paddings[1]; + + const int64_t dilation_height = dilations[0]; + const int64_t dilation_width = dilations[1]; + + const T* input_data = input.data(); + const T* output_data = output.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = context.template Alloc(input_grad); + + if (!channel_last) { + const int64_t input_stride = input_height * input_width; + const int64_t output_stride = output_height * output_width; + for (int64_t i = 0; i < batch_size; i++) { + for (int64_t c = 0; c < output_channels; ++c) { + for (int64_t ph = 0; ph < output_height; ++ph) { + int64_t hstart = ph * stride_height - padding_height; + int64_t hend = hstart + (ksize_height - 1) * dilation_height + 1; + hstart = (hstart < static_cast(0)) + ? hstart + ((-hstart + dilation_height - 1) / + dilation_height) * + dilation_height + : hstart; + + hend = + (hend > input_height) + ? input_height - ((hend - input_height) % dilation_height) + : hend; + for (int64_t pw = 0; pw < output_width; ++pw) { + int64_t wstart = pw * stride_width - padding_width; + int64_t wend = wstart + (ksize_width - 1) * dilation_width + 1; + wstart = (wstart < static_cast(0)) + ? wstart + ((-wstart + dilation_width - 1) / + dilation_width) * + dilation_width + : wstart; + + wend = (wend > input_width) + ? 
input_width - ((wend - input_width) % dilation_width) + : wend; + + bool stop = false; + for (int64_t h = hstart; h < hend && !stop; + h += dilation_height) { + for (int64_t w = wstart; w < wend && !stop; + w += dilation_width) { + int64_t input_idx = h * input_width + w; + int64_t output_idx = ph * output_width + pw; + if (input_data[input_idx] == output_data[output_idx]) { + input_grad_data[input_idx] += output_grad_data[output_idx]; + stop = true; + } + } + } + } + } + input_data += input_stride; + output_data += output_stride; + input_grad_data += input_stride; + output_grad_data += output_stride; + } + } + } else { + const int64_t input_stride = input_height * input_width * input_channels; + const int64_t output_stride = + output_height * output_width * output_channels; + for (int64_t i = 0; i < batch_size; i++) { + for (int64_t c = 0; c < output_channels; ++c) { + for (int64_t ph = 0; ph < output_height; ++ph) { + int64_t hstart = ph * stride_height - padding_height; + int64_t hend = hstart + (ksize_height - 1) * dilation_height + 1; + hstart = (hstart < static_cast(0)) + ? hstart + ((-hstart + dilation_height - 1) / + dilation_height) * + dilation_height + : hstart; + + hend = + (hend > input_height) + ? input_height - ((hend - input_height) % dilation_height) + : hend; + for (int64_t pw = 0; pw < output_width; ++pw) { + int64_t wstart = pw * stride_width - padding_width; + int64_t wend = wstart + (ksize_width - 1) * dilation_width + 1; + wstart = (wstart < static_cast(0)) + ? wstart + ((-wstart + dilation_width - 1) / + dilation_width) * + dilation_width + : wstart; + + wend = (wend > input_width) + ? 
input_width - ((wend - input_width) % dilation_width) + : wend; + + bool stop = false; + for (int64_t h = hstart; h < hend && !stop; + h += dilation_height) { + for (int64_t w = wstart; w < wend && !stop; + w += dilation_width) { + int64_t input_idx = + h * input_width * input_channels + w * input_channels + c; + int64_t output_idx = ph * output_width * output_channels + + pw * output_channels + c; + if (input_data[input_idx] == output_data[output_idx]) { + input_grad_data[input_idx] += output_grad_data[output_idx]; + stop = true; + } + } + } + } + } + } + input_data += input_stride; + output_data += output_stride; + input_grad_data += input_stride; + output_grad_data += output_stride; + } + } + } +}; + /* * tensors are in NCHW or NHWC format. * Ksize, strides are two elements. These two elements represent height @@ -463,7 +772,6 @@ class MaxPool2dGradFunctor { }; template class MaxPool2dGradFunctor; template class MaxPool2dGradFunctor; - template class MaxPool2dGradFunctor; template class Pool2dFunctor, float>; @@ -472,6 +780,14 @@ template class Pool2dFunctor, float>; template class Pool2dGradFunctor, float>; template class Pool2dGradFunctor, float>; template class Pool2dGradFunctor, float>; + +template class MaxPool2DWithDilationsFunctor; +template class MaxPool2DWithDilationsFunctor; +template class MaxPool2DWithDilationsFunctor; +template class MaxPool2DWithDilationsGradFunctor; +template class MaxPool2DWithDilationsGradFunctor; +template class MaxPool2DWithDilationsGradFunctor; + template class Pool2dFunctor, double>; template class Pool2dFunctor, double>; template class Pool2dFunctor, double>; @@ -1184,6 +1500,90 @@ class MaxPool2dWithIndexFunctor { } }; +template +class MaxPool2dWithDilationsAndIndexFunctor { + public: + void operator()(const CPUContext& context, + const DenseTensor& input, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + DenseTensor* output, + DenseTensor* mask) { + 
const int64_t batch_size = input.dims()[0]; + const int64_t input_height = input.dims()[2]; + const int64_t input_width = input.dims()[3]; + const int64_t output_channels = output->dims()[1]; + const int64_t output_height = output->dims()[2]; + const int64_t output_width = output->dims()[3]; + const int64_t ksize_height = ksize[0]; + const int64_t ksize_width = ksize[1]; + const int64_t stride_height = strides[0]; + const int64_t stride_width = strides[1]; + const int64_t padding_height = paddings[0]; + const int64_t padding_width = paddings[1]; + const int64_t dilation_height = dilations[0]; + const int64_t dilation_width = dilations[1]; + const int64_t input_stride = input_height * input_width; + const int64_t output_stride = output_height * output_width; + + const T1* input_data = input.data(); + T1* output_data = context.template Alloc(output); + T2* mask_data = context.template Alloc(mask); + + int64_t hstart = 0, hend = 0; + int64_t wstart = 0, wend = 0; + for (int64_t i = 0; i < batch_size; i++) { + for (int64_t c = 0; c < output_channels; ++c) { + for (int64_t ph = 0; ph < output_height; ++ph) { + hstart = ph * stride_height - padding_height; + hend = hstart + (ksize_height - 1) * dilation_height + 1; + hstart = (hstart < static_cast(0)) + ? hstart + ((-hstart + dilation_height - 1) / + dilation_height) * + dilation_height + : hstart; + + hend = (hend > input_height) + ? input_height - ((hend - input_height) % dilation_height) + : hend; + for (int64_t pw = 0; pw < output_width; ++pw) { + wstart = pw * stride_width - padding_width; + wend = wstart + (ksize_width - 1) * dilation_width + 1; + wstart = (wstart < static_cast(0)) + ? wstart + ((-wstart + dilation_width - 1) / + dilation_width) * + dilation_width + : wstart; + + wend = (wend > input_width) + ? 
input_width - ((wend - input_width) % dilation_width) + : wend; + + T1 ele = static_cast(-FLT_MAX); + int64_t index = -1; + for (int64_t h = hstart; h < hend; h += dilation_height) { + for (int64_t w = wstart; w < wend; w += dilation_width) { + if (ele < input_data[h * input_width + w]) { + ele = input_data[h * input_width + w]; + index = h * input_width + w; + } + } + } + output_data[ph * output_width + pw] = ele; + mask_data[ph * output_width + pw] = index; + } + } + // offset + input_data += input_stride; + output_data += output_stride; + mask_data += output_stride; + } + } + } +}; + /* * All tensors are in NCHW format. * Ksize, strides, paddings are two elements. These two elements represent @@ -1232,10 +1632,61 @@ class MaxPool2dWithIndexGradFunctor { } }; +template +class MaxPool2dWithDilationsAndIndexGradFunctor { + public: + void operator()(const CPUContext& context, + const DenseTensor& output_grad, + const DenseTensor& mask, + const std::vector& ksize UNUSED, + const std::vector& strides UNUSED, + const std::vector& paddings UNUSED, + const std::vector& dilations UNUSED, + DenseTensor* input_grad) { + const int64_t batch_size = input_grad->dims()[0]; + const int64_t input_height = input_grad->dims()[2]; + const int64_t input_width = input_grad->dims()[3]; + const int64_t output_channels = output_grad.dims()[1]; + const int64_t output_height = output_grad.dims()[2]; + const int64_t output_width = output_grad.dims()[3]; + const int64_t input_stride = input_height * input_width; + const int64_t output_stride = output_height * output_width; + + const T2* mask_data = mask.data(); + const T1* output_grad_data = output_grad.data(); + T1* input_grad_data = context.template Alloc(input_grad); + + for (int64_t n = 0; n < batch_size; ++n) { + for (int64_t c = 0; c < output_channels; ++c) { + for (int64_t ph = 0; ph < output_height; ++ph) { + for (int64_t pw = 0; pw < output_width; ++pw) { + const int64_t output_idx = ph * output_width + pw; + const int64_t input_idx 
= + static_cast(mask_data[output_idx]); + input_grad_data[input_idx] += output_grad_data[output_idx]; + } + } + // offset + input_grad_data += input_stride; + output_grad_data += output_stride; + mask_data += output_stride; + } + } + } +}; + template class MaxPool2dWithIndexFunctor; +template class MaxPool2dWithDilationsAndIndexFunctor; template class MaxPool2dWithIndexGradFunctor; +template class MaxPool2dWithDilationsAndIndexGradFunctor; template class MaxPool2dWithIndexFunctor; +template class MaxPool2dWithDilationsAndIndexFunctor; template class MaxPool2dWithIndexGradFunctor; +template class MaxPool2dWithDilationsAndIndexGradFunctor; /* * All tensors are in NCDHW format. diff --git a/paddle/phi/kernels/funcs/pooling.cu b/paddle/phi/kernels/funcs/pooling.cu index 49900d38efdc09..aa0f90ed496552 100644 --- a/paddle/phi/kernels/funcs/pooling.cu +++ b/paddle/phi/kernels/funcs/pooling.cu @@ -102,6 +102,13 @@ static __device__ inline int p_start(int size, return (size + pad < kernel) ? 0 : (size + pad - kernel) / stride + 1; } +static __device__ inline int p_start_with_dilations( + int size, int pad, int dilation, int kernel, int stride) { + return (size + pad < kernel * dilation) + ? 
0 + : (size + pad - kernel * dilation) / stride + 1; +} + static __device__ inline int p_end(int size, int pad, int pooled_size, @@ -219,6 +226,84 @@ __global__ void KernelPool2D(const IndexT nthreads, } } +template +__global__ void KernelMaxPool2DWithDilations( + const IndexT nthreads, + const T* input_data, + const IndexT channels, + const IndexT input_height, + const IndexT input_width, + const IndexT output_height, + const IndexT output_width, + const IndexT ksize_height, + const IndexT ksize_width, + const IndexT stride_height, + const IndexT stride_width, + const IndexT padding_height, + const IndexT padding_width, + const IndexT dilation_height, + const IndexT dilation_width, + FastDivModForPooling divmods, + T* output_data, + bool channel_last = false) { + const IndexT start_index = + static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + const IndexT step = static_cast(blockDim.x) * gridDim.x; + for (IndexT index = start_index; index < nthreads; index += step) { + IndexT hstart, hend, wstart, wend; + IndexT w_offset, h_offset, c_offset, input_offset; + OffsetPreparationFor4Dimension, IndexT>( + index, + channel_last, + divmods, + 0, + 0, + input_width, + input_height, + &w_offset, + &h_offset, + &c_offset, + &input_offset); + input_data += input_offset; + + hstart = h_offset * stride_height - padding_height; + hend = hstart + (ksize_height - 1) * dilation_height + 1; + + hstart = + (hstart < static_cast(0)) + ? hstart + ((-hstart + dilation_height - 1) / dilation_height) * + dilation_height + : hstart; + + hend = (hend > input_height) + ? input_height - ((hend - input_height) % dilation_height) + : hend; + + wstart = w_offset * stride_width - padding_width; + wend = wstart + (ksize_width - 1) * dilation_width + 1; + + wstart = (wstart < static_cast(0)) + ? wstart + ((-wstart + dilation_width - 1) / dilation_width) * + dilation_width + : wstart; + + wend = (wend > input_width) + ? 
input_width - ((wend - input_width) % dilation_width) + : wend; + + T ele = static_cast(-FLT_MAX); + for (IndexT h = hstart; h < hend; h += dilation_height) { + for (IndexT w = wstart; w < wend; w += dilation_width) { + auto input_idx = channel_last + ? (h * input_width + w) * channels + c_offset + : h * input_width + w; + ele = input_data[input_idx] > ele ? input_data[input_idx] : ele; + } + } + output_data[index] = ele; + } +} + template __global__ void AdaptiveKernelPool2D(const IndexT nthreads, const T* input_data, @@ -541,6 +626,154 @@ __global__ void KernelMaxPool2DGradCompatible( } } +template +__global__ void KernelMaxPool2DWithDilationsGrad( + const IndexT nthreads, + const T* input_data, + const T* output_data, + const T* output_grad, + const IndexT channels, + const IndexT input_height, + const IndexT input_width, + const IndexT output_height, + const IndexT output_width, + const IndexT ksize_height, + const IndexT ksize_width, + const IndexT stride_height, + const IndexT stride_width, + const IndexT padding_height, + const IndexT padding_width, + const IndexT dilation_height, + const IndexT dilation_width, + T* input_grad, + FastDivModForPooling divmods, + bool channel_last = false) { + const IndexT start_index = + static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + const IndexT step = static_cast(blockDim.x) * gridDim.x; + for (IndexT index = start_index; index < nthreads; index += step) { + IndexT w_offset, h_offset, c_offset, input_offset; + OffsetPreparationFor4Dimension, IndexT>( + index, + channel_last, + divmods, + 0, + 0, + input_width, + input_height, + &w_offset, + &h_offset, + &c_offset, + &input_offset); + input_data += input_offset; + input_grad += input_offset; + IndexT hstart, hend, wstart, wend; + + hstart = h_offset * stride_height - padding_height; + hend = hstart + (ksize_height - 1) * dilation_height + 1; + + hstart = + (hstart < static_cast(0)) + ? 
hstart + ((-hstart + dilation_height - 1) / dilation_height) * + dilation_height + : hstart; + + hend = (hend > input_height) + ? input_height - ((hend - input_height) % dilation_height) + : hend; + + wstart = w_offset * stride_width - padding_width; + wend = wstart + (ksize_width - 1) * dilation_width + 1; + + wstart = (wstart < static_cast(0)) + ? wstart + ((-wstart + dilation_width - 1) / dilation_width) * + dilation_width + : wstart; + + wend = (wend > input_width) + ? input_width - ((wend - input_width) % dilation_width) + : wend; + + T ele = output_data[index]; + IndexT maxIndex = -1; + bool stop = false; + for (IndexT h = hstart; h < hend && !stop; h += dilation_height) { + for (IndexT w = wstart; w < wend && !stop; w += dilation_width) { + IndexT input_data_idx = + channel_last ? (h * input_width + w) * channels + c_offset + : h * input_width + w; + if (ele == input_data[input_data_idx]) { + maxIndex = input_data_idx; + stop = true; + } + } + } + + if (maxIndex != -1) { + // atomic add + phi::CudaAtomicAdd(input_grad + maxIndex, output_grad[index]); + } + } +} + +template +__global__ void KernelMaxPool2DWithDilationsGradCompatible( + const T* input_data, + const T* output_data, + const T* output_grad, + const IndexT batch_size, + const IndexT channels, + const IndexT input_height, + const IndexT input_width, + const IndexT output_height, + const IndexT output_width, + const IndexT ksize_height, + const IndexT ksize_width, + const IndexT stride_height, + const IndexT stride_width, + const IndexT padding_height, + const IndexT padding_width, + const IndexT dilation_height, + const IndexT dilation_width, + T* input_grad, + FastDivModForPooling divmods, + bool channel_last = false) { + using MPType = typename phi::dtype::MPTypeTrait::Type; + + CUDA_KERNEL_LOOP(index, input_height * input_width) { + IndexT h = index / input_width; + IndexT w = index - h * input_width; + IndexT phstart = p_start_with_dilations( + h, padding_height, dilation_height, ksize_height, 
stride_height); + IndexT phend = p_end(h, padding_height, output_height, stride_height); + IndexT pwstart = p_start_with_dilations( + w, padding_width, dilation_width, ksize_width, stride_width); + IndexT pwend = p_end(w, padding_width, output_width, stride_width); + T input_data_value = input_data[h * input_width + w]; + for (IndexT n = blockIdx.y; n < batch_size; n += gridDim.y) { + for (IndexT c = blockIdx.z; c < channels; c += gridDim.z) { + MPType gradient = static_cast(0.0f); + IndexT offset = (n * channels + c) * output_height * output_width; + for (int ph = phstart; ph < phend; ++ph) { + for (int pw = pwstart; pw < pwend; ++pw) { + IndexT hstart = ph * stride_height - padding_height; + IndexT wstart = pw * stride_width - padding_width; + T output_data_value = output_data[ph * output_width + pw + offset]; + if (((h - hstart) % dilation_height == 0) && + ((w - wstart) % dilation_width == 0) && + (output_data_value == input_data_value)) { + gradient += static_cast( + output_grad[ph * output_width + pw + offset]); + } + } + } + input_grad[(n * channels + c) * input_height * input_width + index] = + static_cast(gradient); + } + } + } +} + template void Pool2dDirectCUDAFunctor::operator()( const T* input, @@ -802,31 +1035,21 @@ class Pool2dFunctor { } } }; -/* - * Tensors are in NCHW or NHWC format. - * Ksize, strides are two elements. These two elements represent height - * and width, respectively. - * Paddings are four elements. These four elements represent height_up, - * height_down, width_left and width_right, respectively. 
- */ -template -class Pool2dGradFunctor { + +template +class MaxPool2DWithDilationsFunctor { public: void operator()(const phi::GPUContext& dev_ctx, const DenseTensor& input, - const DenseTensor& output, - const DenseTensor& output_grad, const std::vector& ksize, const std::vector& strides, const std::vector& paddings, + const std::vector& dilations, const std::string data_format, - bool exclusive, - bool adaptive, - DenseTensor* input_grad, - PoolProcess pool_process) { + DenseTensor* output) { bool channel_last = (data_format == "NHWC"); - const int64_t batch_size = input.dims()[0]; + const int64_t input_channels = channel_last ? input.dims()[3] : input.dims()[1]; const int64_t input_height = @@ -835,11 +1058,11 @@ class Pool2dGradFunctor { channel_last ? input.dims()[2] : input.dims()[3]; const int64_t output_channels = - channel_last ? output.dims()[3] : output.dims()[1]; + channel_last ? output->dims()[3] : output->dims()[1]; const int64_t output_height = - channel_last ? output.dims()[1] : output.dims()[2]; + channel_last ? output->dims()[1] : output->dims()[2]; const int64_t output_width = - channel_last ? output.dims()[2] : output.dims()[3]; + channel_last ? 
output->dims()[2] : output->dims()[3]; const int64_t ksize_height = ksize[0]; const int64_t ksize_width = ksize[1]; @@ -850,66 +1073,178 @@ class Pool2dGradFunctor { const int64_t padding_height = paddings[0]; const int64_t padding_width = paddings[1]; + const int64_t dilation_height = dilations[0]; + const int64_t dilation_width = dilations[1]; + const T* input_data = input.data(); - const T* output_data = output.data(); - const T* output_grad_data = output_grad.data(); - T* input_grad_data = dev_ctx.template Alloc(input_grad); + T* output_data = dev_ctx.template Alloc(output); - int64_t nthreads = batch_size * input_channels * input_height * input_width; - auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, nthreads); - if (input.numel() <= std::numeric_limits::max() && - output.numel() <= std::numeric_limits::max()) { - auto pool_divmods = FastDivModForPoolingWithMoreStaff(input_channels, - input_width, - input_height, - ksize_width, - ksize_height, - stride_width, - stride_height); - KernelPool2DGrad - <<>>(nthreads, - input_data, - output_data, - output_grad_data, - output_width, - output_height, - input_width, - input_height, - ksize_width, - ksize_height, - stride_width, - stride_height, - padding_width, - padding_height, - pool_divmods, - pool_process, - exclusive, - adaptive, - input_grad_data, - channel_last); - } else { - auto pool_divmods = - FastDivModForPoolingWithMoreStaff(input_channels, - input_width, - input_height, - ksize_width, - ksize_height, - stride_width, - stride_height); - KernelPool2DGrad - <<>>(nthreads, - input_data, - output_data, - output_grad_data, - output_width, - output_height, - input_width, + std::array max_grid_dim = dev_ctx.GetCUDAMaxGridDimSize(); + int64_t nthreads = + batch_size * output_channels * output_height * output_width; + int thread_num = 1024; +#ifdef WITH_NV_JETSON + backends::gpu::ChangeThreadNum(dev_ctx, &thread_num); +#endif + int64_t blocks = (nthreads + thread_num - 1) / thread_num; + dim3 
threads(thread_num, 1); + dim3 grid(blocks, 1); + if (input.numel() <= std::numeric_limits::max()) { + auto pool_divmods = FastDivModForPooling( + input_channels, output_width, output_height); + KernelMaxPool2DWithDilations + <<>>(nthreads, + input_data, + input_channels, + input_height, + input_width, + output_height, + output_width, + ksize_height, + ksize_width, + stride_height, + stride_width, + padding_height, + padding_width, + dilation_height, + dilation_width, + pool_divmods, + output_data, + channel_last); + } else { + auto pool_divmods = FastDivModForPooling( + input_channels, output_width, output_height); + KernelMaxPool2DWithDilations + <<>>(nthreads, + input_data, + input_channels, + input_height, + input_width, + output_height, + output_width, + ksize_height, + ksize_width, + stride_height, + stride_width, + padding_height, + padding_width, + dilation_height, + dilation_width, + pool_divmods, + output_data, + channel_last); + } + } +}; + +/* + * Tensors are in NCHW or NHWC format. + * Ksize, strides are two elements. These two elements represent height + * and width, respectively. + * Paddings are four elements. These four elements represent height_up, + * height_down, width_left and width_right, respectively. + */ +template +class Pool2dGradFunctor { + public: + void operator()(const phi::GPUContext& dev_ctx, + const DenseTensor& input, + const DenseTensor& output, + const DenseTensor& output_grad, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, + const std::string data_format, + bool exclusive, + bool adaptive, + DenseTensor* input_grad, + PoolProcess pool_process) { + bool channel_last = (data_format == "NHWC"); + + const int64_t batch_size = input.dims()[0]; + const int64_t input_channels = + channel_last ? input.dims()[3] : input.dims()[1]; + const int64_t input_height = + channel_last ? input.dims()[1] : input.dims()[2]; + const int64_t input_width = + channel_last ? 
input.dims()[2] : input.dims()[3]; + + const int64_t output_channels = + channel_last ? output.dims()[3] : output.dims()[1]; + const int64_t output_height = + channel_last ? output.dims()[1] : output.dims()[2]; + const int64_t output_width = + channel_last ? output.dims()[2] : output.dims()[3]; + + const int64_t ksize_height = ksize[0]; + const int64_t ksize_width = ksize[1]; + + const int64_t stride_height = strides[0]; + const int64_t stride_width = strides[1]; + + const int64_t padding_height = paddings[0]; + const int64_t padding_width = paddings[1]; + + const T* input_data = input.data(); + const T* output_data = output.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = dev_ctx.template Alloc(input_grad); + + int64_t nthreads = batch_size * input_channels * input_height * input_width; + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, nthreads); + if (input.numel() <= std::numeric_limits::max() && + output.numel() <= std::numeric_limits::max()) { + auto pool_divmods = FastDivModForPoolingWithMoreStaff(input_channels, + input_width, + input_height, + ksize_width, + ksize_height, + stride_width, + stride_height); + KernelPool2DGrad + <<>>(nthreads, + input_data, + output_data, + output_grad_data, + output_width, + output_height, + input_width, + input_height, + ksize_width, + ksize_height, + stride_width, + stride_height, + padding_width, + padding_height, + pool_divmods, + pool_process, + exclusive, + adaptive, + input_grad_data, + channel_last); + } else { + auto pool_divmods = + FastDivModForPoolingWithMoreStaff(input_channels, + input_width, + input_height, + ksize_width, + ksize_height, + stride_width, + stride_height); + KernelPool2DGrad + <<>>(nthreads, + input_data, + output_data, + output_grad_data, + output_width, + output_height, + input_width, input_height, ksize_width, ksize_height, @@ -1092,6 +1427,176 @@ class MaxPool2dGradFunctor { } }; +template +class MaxPool2DWithDilationsGradFunctor { + public: + 
void operator()(const phi::GPUContext& dev_ctx, + const DenseTensor& input, + const DenseTensor& output, + const DenseTensor& output_grad, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + const std::string data_format, + DenseTensor* input_grad) { + static const int kBlockThreads = 1024; + + bool channel_last = (data_format == "NHWC"); + + const int64_t batch_size = input.dims()[0]; + + const int64_t input_channels = + channel_last ? input.dims()[3] : input.dims()[1]; + const int64_t input_height = + channel_last ? input.dims()[1] : input.dims()[2]; + const int64_t input_width = + channel_last ? input.dims()[2] : input.dims()[3]; + + const int64_t output_channels = + channel_last ? output.dims()[3] : output.dims()[1]; + const int64_t output_height = + channel_last ? output.dims()[1] : output.dims()[2]; + const int64_t output_width = + channel_last ? output.dims()[2] : output.dims()[3]; + + const int64_t ksize_height = ksize[0]; + const int64_t ksize_width = ksize[1]; + + const int64_t stride_height = strides[0]; + const int64_t stride_width = strides[1]; + + const int64_t padding_height = paddings[0]; + const int64_t padding_width = paddings[1]; + + const int64_t dilation_height = dilations[0]; + const int64_t dilation_width = dilations[1]; + + const T* input_data = input.data(); + const T* output_data = output.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = dev_ctx.template Alloc(input_grad); + + int64_t nthreads = + batch_size * output_channels * output_height * output_width; + dim3 threads(kBlockThreads, 1); + + if (input.numel() <= std::numeric_limits::max() && + output.numel() <= std::numeric_limits::max()) { + auto pool_divmods = FastDivModForPooling( + input_channels, output_width, output_height); + if (FLAGS_use_accuracy_compatible_kernel) { + int64_t blocks = + (input_width * input_height + kBlockThreads - 1) / kBlockThreads; + dim3 grid(blocks, 
batch_size, input_channels); + // NOTE: input.numel() <= std::numeric_limits::max() && + // output.numel() <= std::numeric_limits::max() + KernelMaxPool2DWithDilationsGradCompatible + <<>>(input_data, + output_data, + output_grad_data, + batch_size, + input_channels, + input_height, + input_width, + output_height, + output_width, + ksize_height, + ksize_width, + stride_height, + stride_width, + padding_height, + padding_width, + dilation_height, + dilation_width, + input_grad_data, + pool_divmods, + channel_last); + } else { + int64_t blocks = (nthreads + kBlockThreads - 1) / kBlockThreads; + dim3 grid(blocks, 1); + // NOTE: input.numel() <= std::numeric_limits::max() && + // output.numel() <= std::numeric_limits::max() + KernelMaxPool2DWithDilationsGrad + <<>>(nthreads, + input_data, + output_data, + output_grad_data, + input_channels, + input_height, + input_width, + output_height, + output_width, + ksize_height, + ksize_width, + stride_height, + stride_width, + padding_height, + padding_width, + dilation_height, + dilation_width, + input_grad_data, + pool_divmods, + channel_last); + } + + } else { + auto pool_divmods = FastDivModForPooling( + input_channels, output_width, output_height); + if (FLAGS_use_accuracy_compatible_kernel) { + int64_t blocks = + (input_width * input_height + kBlockThreads - 1) / kBlockThreads; + dim3 grid(blocks, batch_size, input_channels); + KernelMaxPool2DWithDilationsGradCompatible + <<>>(input_data, + output_data, + output_grad_data, + batch_size, + input_channels, + input_height, + input_width, + output_height, + output_width, + ksize_height, + ksize_width, + stride_height, + stride_width, + padding_height, + padding_width, + dilation_height, + dilation_width, + input_grad_data, + pool_divmods, + channel_last); + } else { + int64_t blocks = (nthreads + kBlockThreads - 1) / kBlockThreads; + dim3 grid(blocks, 1); + KernelMaxPool2DWithDilationsGrad + <<>>(nthreads, + input_data, + output_data, + output_grad_data, + input_channels, + 
input_height, + input_width, + output_height, + output_width, + ksize_height, + ksize_width, + stride_height, + stride_width, + padding_height, + padding_width, + dilation_height, + dilation_width, + input_grad_data, + pool_divmods, + channel_last); + } + } + } +}; + template class PADDLE_API Pool2dDirectCUDAFunctor, float>; template class PADDLE_API Pool2dDirectCUDAFunctor, float>; @@ -1100,13 +1605,22 @@ template class MaxPool2dGradFunctor; template class MaxPool2dGradFunctor; template class MaxPool2dGradFunctor; +template class MaxPool2DWithDilationsGradFunctor; +template class MaxPool2DWithDilationsGradFunctor; +template class MaxPool2DWithDilationsGradFunctor; +template class MaxPool2DWithDilationsGradFunctor; + template class Pool2dFunctor, float>; +template class MaxPool2DWithDilationsFunctor; template class Pool2dFunctor, float>; template class Pool2dFunctor, float>; template class Pool2dGradFunctor, float>; template class Pool2dGradFunctor, float>; template class Pool2dGradFunctor, float>; template class Pool2dFunctor, double>; +template class MaxPool2DWithDilationsFunctor; template class Pool2dFunctor, double>; template class Pool2dFunctor, double>; template class Pool2dGradFunctor, double>; @@ -1116,6 +1630,7 @@ template class Pool2dGradFunctor, double>; template class Pool2dFunctor, dtype::float16>; +template class MaxPool2DWithDilationsFunctor; template class Pool2dFunctor, dtype::float16>; @@ -1134,6 +1649,7 @@ template class Pool2dGradFunctor, dtype::bfloat16>; +template class MaxPool2DWithDilationsFunctor; template class Pool2dFunctor, dtype::bfloat16>; @@ -1973,23 +2489,95 @@ template class Pool3dGradFunctor; template -__global__ void KernelMaxPool2dWithIdx(const IndexT nthreads, - const T1* input_data, - const IndexT channels, - const IndexT input_height, - const IndexT input_width, - const IndexT output_height, - const IndexT output_width, - const IndexT ksize_height, - const IndexT ksize_width, - const IndexT stride_height, - const IndexT 
stride_width, - const IndexT padding_height, - const IndexT padding_width, - bool adaptive, - T1* output_data, - T2* mask_data, - FastDivModForPooling divmods) { +__global__ void KernelMaxPool2dWithIdx(const IndexT nthreads, + const T1* input_data, + const IndexT channels, + const IndexT input_height, + const IndexT input_width, + const IndexT output_height, + const IndexT output_width, + const IndexT ksize_height, + const IndexT ksize_width, + const IndexT stride_height, + const IndexT stride_width, + const IndexT padding_height, + const IndexT padding_width, + bool adaptive, + T1* output_data, + T2* mask_data, + FastDivModForPooling divmods) { + const IndexT start_index = + static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + const IndexT step = static_cast(blockDim.x) * gridDim.x; + for (IndexT index = start_index; index < nthreads; index += step) { + IndexT hstart, hend, wstart, wend; + IndexT w_offset, h_offset, c_offset, input_offset; + OffsetPreparationFor4Dimension, IndexT>( + index, + false, + divmods, + 0, + 0, + input_width, + input_height, + &w_offset, + &h_offset, + &c_offset, + &input_offset); + input_data += input_offset; + + if (adaptive) { + hstart = AdaptStartIndex(h_offset, input_height, output_height); + hend = AdaptEndIndex(h_offset, input_height, output_height); + + wstart = AdaptStartIndex(w_offset, input_width, output_width); + wend = AdaptEndIndex(w_offset, input_width, output_width); + } else { + hstart = h_offset * stride_height - padding_height; + hend = min(hstart + ksize_height, input_height); + hstart = max(hstart, static_cast(0)); + + wstart = w_offset * stride_width - padding_width; + wend = min(wstart + ksize_width, input_width); + wstart = max(wstart, static_cast(0)); + } + + T1 ele = static_cast(-FLT_MAX); + IndexT max_index = -1; + for (IndexT h = hstart; h < hend; ++h) { + for (IndexT w = wstart; w < wend; ++w) { + IndexT input_index = h * input_width + w; + if (ele < input_data[input_index]) { + max_index = input_index; + ele 
= input_data[input_index]; + } + } + } + output_data[index] = ele; + mask_data[index] = max_index; + } +} + +template +__global__ void KernelMaxPool2dWithDilationsAndIdx( + const IndexT nthreads, + const T1* input_data, + const IndexT channels, + const IndexT input_height, + const IndexT input_width, + const IndexT output_height, + const IndexT output_width, + const IndexT ksize_height, + const IndexT ksize_width, + const IndexT stride_height, + const IndexT stride_width, + const IndexT padding_height, + const IndexT padding_width, + const IndexT dilation_height, + const IndexT dilation_width, + T1* output_data, + T2* mask_data, + FastDivModForPooling divmods) { const IndexT start_index = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; const IndexT step = static_cast(blockDim.x) * gridDim.x; @@ -2010,26 +2598,35 @@ __global__ void KernelMaxPool2dWithIdx(const IndexT nthreads, &input_offset); input_data += input_offset; - if (adaptive) { - hstart = AdaptStartIndex(h_offset, input_height, output_height); - hend = AdaptEndIndex(h_offset, input_height, output_height); + hstart = h_offset * stride_height - padding_height; + hend = hstart + (ksize_height - 1) * dilation_height + 1; - wstart = AdaptStartIndex(w_offset, input_width, output_width); - wend = AdaptEndIndex(w_offset, input_width, output_width); - } else { - hstart = h_offset * stride_height - padding_height; - hend = min(hstart + ksize_height, input_height); - hstart = max(hstart, static_cast(0)); + hstart = + (hstart < static_cast(0)) + ? hstart + ((-hstart + dilation_height - 1) / dilation_height) * + dilation_height + : hstart; - wstart = w_offset * stride_width - padding_width; - wend = min(wstart + ksize_width, input_width); - wstart = max(wstart, static_cast(0)); - } + hend = (hend > input_height) + ? 
input_height - ((hend - input_height) % dilation_height) + : hend; + + wstart = w_offset * stride_width - padding_width; + wend = wstart + (ksize_width - 1) * dilation_width + 1; + + wstart = (wstart < static_cast(0)) + ? wstart + ((-wstart + dilation_width - 1) / dilation_width) * + dilation_width + : wstart; + + wend = (wend > input_width) + ? input_width - ((wend - input_width) % dilation_width) + : wend; T1 ele = static_cast(-FLT_MAX); IndexT max_index = -1; - for (IndexT h = hstart; h < hend; ++h) { - for (IndexT w = wstart; w < wend; ++w) { + for (IndexT h = hstart; h < hend; h += dilation_height) { + for (IndexT w = wstart; w < wend; w += dilation_width) { IndexT input_index = h * input_width + w; if (ele < input_data[input_index]) { max_index = input_index; @@ -2174,6 +2771,77 @@ __global__ void KernelMaxPool2DWithIdxGrad( } } +template +__global__ void KernelMaxPool2DWithDilationsAndIdxGrad( + const IndexT nthreads, + const T1* output_grad, + const T2* mask_data, + const IndexT channels, + const IndexT input_height, + const IndexT input_width, + const IndexT output_height, + const IndexT output_width, + const IndexT ksize_height, + const IndexT ksize_width, + const IndexT stride_height, + const IndexT stride_width, + const IndexT padding_height, + const IndexT padding_width, + const IndexT dilation_height, + const IndexT dilation_width, + T1* input_grad, + FastDivModForPooling divmods) { + const IndexT start_index = + static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + const IndexT step = static_cast(blockDim.x) * gridDim.x; + for (IndexT index = start_index; index < nthreads; index += step) { + IndexT phstart, phend, pwstart, pwend; + IndexT w_offset, h_offset, c_offset, output_offset; + OffsetPreparationFor4Dimension, IndexT>( + index, + false, + divmods, + 0, + 0, + output_width, + output_height, + &w_offset, + &h_offset, + &c_offset, + &output_offset); + mask_data += output_offset; + output_grad += output_offset; + + phstart = + (h_offset + 
padding_height < ksize_height * dilation_height) + ? 0 + : (h_offset + padding_height - ksize_height * dilation_height) / + stride_height + + 1; + pwstart = (w_offset + padding_width < ksize_width * dilation_width) + ? 0 + : (w_offset + padding_width - ksize_width * dilation_width) / + stride_width + + 1; + phend = min((h_offset + padding_height) / stride_height + 1, output_height); + pwend = min((w_offset + padding_width) / stride_width + 1, output_width); + + T1 input_grad_data = static_cast(0); + IndexT input_current_featuremap_idx = h_offset * input_width + w_offset; + for (IndexT ph = phstart; ph < phend; ++ph) { + for (IndexT pw = pwstart; pw < pwend; ++pw) { + IndexT hstart = ph * stride_height - padding_height; + IndexT wstart = pw * stride_width - padding_width; + if (((h_offset - hstart) % dilation_height == 0) && + ((w_offset - wstart) % dilation_width == 0) && + (mask_data[ph * output_width + pw] == input_current_featuremap_idx)) + input_grad_data += output_grad[ph * output_width + pw]; + } + } + input_grad[index] = input_grad_data; + } +} + /* * All tensors are in NCHW format. * Ksize, strides, paddings are two elements. 
These two elements represent @@ -2322,6 +2990,94 @@ class MaxPool2dWithIndexFunctor { } }; +template +class MaxPool2dWithDilationsAndIndexFunctor { + public: + void operator()(const phi::GPUContext& dev_ctx, + const DenseTensor& input, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + DenseTensor* output, + DenseTensor* mask) { + const int64_t batch_size = input.dims()[0]; + const int64_t input_channels = input.dims()[1]; + const int64_t input_height = input.dims()[2]; + const int64_t input_width = input.dims()[3]; + const int64_t output_channels = output->dims()[1]; + const int64_t output_height = output->dims()[2]; + const int64_t output_width = output->dims()[3]; + const int64_t ksize_height = ksize[0]; + const int64_t ksize_width = ksize[1]; + const int64_t stride_height = strides[0]; + const int64_t stride_width = strides[1]; + const int64_t padding_height = paddings[0]; + const int64_t padding_width = paddings[1]; + const int64_t dilation_height = dilations[0]; + const int64_t dilation_width = dilations[1]; + + const T1* input_data = input.data(); + T1* output_data = dev_ctx.template Alloc(output); + T2* mask_data = dev_ctx.template Alloc(mask); + + int64_t nthreads = static_cast(batch_size) * output_channels * + output_height * output_width; + int thread_num = 1024; +#ifdef WITH_NV_JETSON + backends::gpu::ChangeThreadNum(dev_ctx, &thread_num); +#endif + int64_t blocks = (nthreads + thread_num - 1) / thread_num; + dim3 threads(thread_num, 1); + dim3 grid(blocks, 1); + if (input.numel() <= std::numeric_limits::max()) { + auto pool_divmods = FastDivModForPooling( + input_channels, output_width, output_height); + KernelMaxPool2dWithDilationsAndIdx + <<>>(nthreads, + input_data, + input_channels, + input_height, + input_width, + output_height, + output_width, + ksize_height, + ksize_width, + stride_height, + stride_width, + padding_height, + padding_width, + dilation_height, + 
dilation_width, + output_data, + mask_data, + pool_divmods); + } else { + auto pool_divmods = FastDivModForPooling( + input_channels, output_width, output_height); + KernelMaxPool2dWithDilationsAndIdx + <<>>(nthreads, + input_data, + input_channels, + input_height, + input_width, + output_height, + output_width, + ksize_height, + ksize_width, + stride_height, + stride_width, + padding_height, + padding_width, + dilation_height, + dilation_width, + output_data, + mask_data, + pool_divmods); + } + } +}; + /* * All tensors are in NCHW format. * Ksize, strides, paddings are two elements. These two elements represent @@ -2407,6 +3163,90 @@ class MaxPool2dWithIndexGradFunctor { } }; +template +class MaxPool2dWithDilationsAndIndexGradFunctor { + public: + void operator()(const phi::GPUContext& dev_ctx, + const DenseTensor& output_grad, + const DenseTensor& mask, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + DenseTensor* input_grad) { + const int64_t batch_size = input_grad->dims()[0]; + const int64_t input_channels = input_grad->dims()[1]; + const int64_t input_height = input_grad->dims()[2]; + const int64_t input_width = input_grad->dims()[3]; + const int64_t output_height = output_grad.dims()[2]; + const int64_t output_width = output_grad.dims()[3]; + const int64_t ksize_height = ksize[0]; + const int64_t ksize_width = ksize[1]; + const int64_t stride_height = strides[0]; + const int64_t stride_width = strides[1]; + const int64_t padding_height = paddings[0]; + const int64_t padding_width = paddings[1]; + const int64_t dilation_height = dilations[0]; + const int64_t dilation_width = dilations[1]; + + const T2* mask_data = mask.data(); + const T1* output_grad_data = output_grad.data(); + T1* input_grad_data = dev_ctx.template Alloc(input_grad); + + int64_t nthreads = static_cast(batch_size) * input_channels * + input_height * input_width; + int64_t blocks = (nthreads + 1024 - 1) / 1024; + dim3 
threads(1024, 1); + dim3 grid(blocks, 1); + + if (nthreads <= std::numeric_limits::max()) { + auto pool_divmods = + FastDivModForPooling(input_channels, input_width, input_height); + KernelMaxPool2DWithDilationsAndIdxGrad + <<>>(nthreads, + output_grad_data, + mask_data, + input_channels, + input_height, + input_width, + output_height, + output_width, + ksize_height, + ksize_width, + stride_height, + stride_width, + padding_height, + padding_width, + dilation_height, + dilation_width, + input_grad_data, + pool_divmods); + } else { + auto pool_divmods = FastDivModForPooling( + input_channels, input_width, input_height); + KernelMaxPool2DWithDilationsAndIdxGrad + <<>>(nthreads, + output_grad_data, + mask_data, + input_channels, + input_height, + input_width, + output_height, + output_width, + ksize_height, + ksize_width, + stride_height, + stride_width, + padding_height, + padding_width, + dilation_height, + dilation_width, + input_grad_data, + pool_divmods); + } + } +}; + template class MaxPool2dWithIndexFunctor; template class MaxPool2dWithIndexGradFunctor; template class MaxPool2dWithIndexFunctor; @@ -2420,6 +3260,31 @@ template class MaxPool2dWithIndexGradFunctor; +template class MaxPool2dWithDilationsAndIndexFunctor; +template class MaxPool2dWithDilationsAndIndexGradFunctor; +template class MaxPool2dWithDilationsAndIndexFunctor; +template class MaxPool2dWithDilationsAndIndexGradFunctor; +template class MaxPool2dWithDilationsAndIndexFunctor; +template class MaxPool2dWithDilationsAndIndexGradFunctor; +template class MaxPool2dWithDilationsAndIndexFunctor; +template class MaxPool2dWithDilationsAndIndexGradFunctor; + template __global__ void KernelMaxPool3DWithIdx( const IndexT ncd, diff --git a/paddle/phi/kernels/funcs/pooling.h b/paddle/phi/kernels/funcs/pooling.h index af13745d27eda8..d1591fd3607fae 100644 --- a/paddle/phi/kernels/funcs/pooling.h +++ b/paddle/phi/kernels/funcs/pooling.h @@ -222,6 +222,32 @@ class Pool2dFunctor { PoolProcess pool_compute); }; 
+template +class MaxPool2DWithDilationsFunctor { + public: + void operator()(const Context& context, + const DenseTensor& input, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + const std::string data_format, + DenseTensor* output); +}; + +template +class MaxPool2dWithDilationsAndIndexFunctor { + public: + void operator()(const Context& context, + const DenseTensor& input, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + DenseTensor* output, + DenseTensor* mask); +}; + template class Pool2dGradFunctor { public: @@ -239,6 +265,34 @@ class Pool2dGradFunctor { PoolProcess pool_compute); }; +template +class MaxPool2DWithDilationsGradFunctor { + public: + void operator()(const Context& context, + const DenseTensor& input, + const DenseTensor& output, + const DenseTensor& output_grad, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + const std::string data_format, + DenseTensor* input_grad); +}; + +template +class MaxPool2dWithDilationsAndIndexGradFunctor { + public: + void operator()(const Context& context, + const DenseTensor& output_grad, + const DenseTensor& mask, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + DenseTensor* input_grad); +}; + template class MaxPool2dGradFunctor { public: diff --git a/paddle/phi/kernels/gpu/pool_grad_kernel.cu b/paddle/phi/kernels/gpu/pool_grad_kernel.cu index 4c38158e1d7c3d..113e72f1aa2be2 100644 --- a/paddle/phi/kernels/gpu/pool_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/pool_grad_kernel.cu @@ -25,6 +25,14 @@ PD_REGISTER_KERNEL(pool2d_grad, double, phi::float16, phi::bfloat16) {} +PD_REGISTER_KERNEL(max_pool2d_with_dilations_grad, + GPU, + ALL_LAYOUT, + phi::MaxPool2DWithDilationsGradKernel, + float, + double, + phi::float16, + 
phi::bfloat16) {} PD_REGISTER_KERNEL(lp_pool2d_grad, GPU, ALL_LAYOUT, diff --git a/paddle/phi/kernels/gpu/pool_kernel.cu b/paddle/phi/kernels/gpu/pool_kernel.cu index 79e20516b6f676..ecafd6876bb4e8 100644 --- a/paddle/phi/kernels/gpu/pool_kernel.cu +++ b/paddle/phi/kernels/gpu/pool_kernel.cu @@ -25,6 +25,14 @@ PD_REGISTER_KERNEL(pool2d, double, phi::float16, phi::bfloat16) {} +PD_REGISTER_KERNEL(max_pool2d_with_dilations, + GPU, + ALL_LAYOUT, + phi::MaxPool2DWithDilationsKernel, + float, + double, + phi::float16, + phi::bfloat16) {} PD_REGISTER_KERNEL(lp_pool2d, GPU, ALL_LAYOUT, diff --git a/paddle/phi/kernels/impl/pool_grad_kernel_impl.h b/paddle/phi/kernels/impl/pool_grad_kernel_impl.h index 39059edeca4c37..0782720dfe948a 100644 --- a/paddle/phi/kernels/impl/pool_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/pool_grad_kernel_impl.h @@ -169,6 +169,73 @@ void PoolGradRawKernel(const Context& dev_ctx, } } +template +void MaxPool2DWithDilationsGradRawKernel( + const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out, + const DenseTensor& dout, + const std::vector& kernel_size, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + const std::string& data_format, + bool global_pooling, + const std::string& padding_algorithm, + DenseTensor* dx) { + if (dx && dx->numel() == 0) { + dev_ctx.template Alloc(dx); + return; + } + const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC"); + std::vector paddings_ = paddings; + std::vector dilations_ = dilations; + std::vector kernel_size_ = kernel_size; + + // update paddings + auto x_dims = x.dims(); + DDim data_dims; + if (channel_last) { + data_dims = slice_ddim(x_dims, 1, x_dims.size() - 1); + } else { + data_dims = slice_ddim(x_dims, 2, x_dims.size()); + } + funcs::UpdatePadding(&paddings_, + global_pooling, + false, + padding_algorithm, + data_dims, + strides, + kernel_size_); + if (data_dims.size() * 2 == static_cast(paddings_.size())) { + for 
(int i = 0; i < data_dims.size(); ++i) { + paddings_.erase(paddings_.begin() + i + 1); + } + } + + if (global_pooling) { + funcs::UpdateKernelSize(&kernel_size_, data_dims); + } + + if (dx) { + dev_ctx.template Alloc(dx); + funcs::SetConstant set_constant; + set_constant(dev_ctx, dx, static_cast(0.0)); + + funcs::MaxPool2DWithDilationsGradFunctor pool2d_backward; + pool2d_backward(dev_ctx, + x, + out, + dout, + kernel_size_, + strides, + paddings_, + dilations_, + data_format, + dx); + } +} + template void MaxPoolWithIndexGradRawKernel(const Context& dev_ctx, const DenseTensor& x UNUSED, @@ -230,6 +297,45 @@ void MaxPoolWithIndexGradRawKernel(const Context& dev_ctx, } } +template +void MaxPool2dWithDilationsAndIndexGradRawKernel( + const Context& dev_ctx, + const DenseTensor& x UNUSED, + const DenseTensor& mask, + const DenseTensor& dout, + const std::vector& kernel_size, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + bool global_pooling, + DenseTensor* dx) { + if (dx && dx->numel() == 0) { + dev_ctx.template Alloc(dx); + return; + } + std::vector paddings_(paddings.begin(), paddings.end()); + std::vector kernel_size_(kernel_size.begin(), kernel_size.end()); + std::vector strides_(strides.begin(), strides.end()); + std::vector dilations_(dilations.begin(), dilations.end()); + + if (global_pooling) { + for (size_t i = 0; i < kernel_size_.size(); ++i) { + paddings_[i] = 0; + kernel_size_[i] = static_cast(dx->dims()[i + 2]); + } + } + + if (dx) { + dev_ctx.template Alloc(dx); + funcs::set_constant(dev_ctx, dx, static_cast(0)); + + funcs::MaxPool2dWithDilationsAndIndexGradFunctor + pool2d_backward; + pool2d_backward( + dev_ctx, dout, mask, kernel_size_, strides_, paddings_, dilations_, dx); + } +} + template void Pool2dGradKernel(const Context& dev_ctx, const DenseTensor& x, @@ -263,6 +369,34 @@ void Pool2dGradKernel(const Context& dev_ctx, dx); } +template +void MaxPool2DWithDilationsGradKernel(const Context& dev_ctx, + 
const DenseTensor& x, + const DenseTensor& out, + const DenseTensor& dout, + const IntArray& kernel_size, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + bool ceil_mode UNUSED, + const std::string& data_format, + bool global_pooling, + const std::string& padding_algorithm, + DenseTensor* dx) { + MaxPool2DWithDilationsGradRawKernel(dev_ctx, + x, + out, + dout, + kernel_size.GetData(), + strides, + paddings, + dilations, + data_format, + global_pooling, + padding_algorithm, + dx); +} + template void LPPool2dGradKernel(const Context& dev_ctx, const DenseTensor& x, @@ -339,10 +473,25 @@ void MaxPool2dWithIndexGradKernel(const Context& dev_ctx, const std::vector& kernel_size, const std::vector& strides, const std::vector& paddings, + const std::vector& dilations, bool global_pooling, bool adaptive, bool ceil_mode UNUSED, DenseTensor* dx) { + if (dilations[0] > 1 || dilations[1] > 1) { + MaxPool2dWithDilationsAndIndexGradRawKernel(dev_ctx, + x, + mask, + dout, + kernel_size, + strides, + paddings, + dilations, + global_pooling, + dx); + return; + } + MaxPoolWithIndexGradRawKernel(dev_ctx, x, mask, diff --git a/paddle/phi/kernels/impl/pool_kernel_impl.h b/paddle/phi/kernels/impl/pool_kernel_impl.h index 87e74f1786e469..c4249fbb788720 100644 --- a/paddle/phi/kernels/impl/pool_kernel_impl.h +++ b/paddle/phi/kernels/impl/pool_kernel_impl.h @@ -217,6 +217,64 @@ void PoolRawKernel(const Context& dev_ctx, } } +template +void MaxPool2DWithDilationsRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& kernel_size, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + const std::string& data_format, + bool global_pooling, + const std::string& padding_algorithm, + DenseTensor* out) { + if (x.numel() == 0) { + phi::Full( + dev_ctx, phi::IntArray(common::vectorize(out->dims())), NAN, out); + return; + } + const bool channel_last = (data_format == "NHWC" || data_format == 
"NDHWC"); + std::vector paddings_ = paddings; + std::vector dilations_ = dilations; + std::vector kernel_size_ = kernel_size; + + // update paddings + auto x_dims = x.dims(); + DDim data_dims; + if (channel_last) { + data_dims = slice_ddim(x_dims, 1, x_dims.size() - 1); + } else { + data_dims = slice_ddim(x_dims, 2, x_dims.size()); + } + + funcs::UpdatePadding(&paddings_, + global_pooling, + false, + padding_algorithm, + data_dims, + strides, + kernel_size_); + + if (data_dims.size() * 2 == static_cast(paddings_.size())) { + for (int i = 0; i < data_dims.size(); ++i) { + paddings_.erase(paddings_.begin() + i + 1); + } + } + + if (global_pooling) { + funcs::UpdateKernelSize(&kernel_size_, data_dims); + } + funcs::MaxPool2DWithDilationsFunctor pool2d_forward; + pool2d_forward(dev_ctx, + x, + kernel_size_, + strides, + paddings_, + dilations_, + data_format, + out); +} + template void MaxPoolWithIndexRawKernel(const Context& dev_ctx, const DenseTensor& x, @@ -268,6 +326,46 @@ void MaxPoolWithIndexRawKernel(const Context& dev_ctx, } } +template +void MaxPool2dWithDilationsAndIndexRawKernel( + const Context& dev_ctx, + const DenseTensor& x, + const std::vector& kernel_size, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + bool global_pooling, + DenseTensor* out, + DenseTensor* mask) { + if (x.numel() == 0) { + if (out) { + phi::Full( + dev_ctx, phi::IntArray(common::vectorize(out->dims())), NAN, out); + } + if (mask) { + phi::Full( + dev_ctx, phi::IntArray(common::vectorize(mask->dims())), 0, mask); + } + return; + } + + std::vector paddings_(paddings.begin(), paddings.end()); + std::vector dilations_(dilations.begin(), dilations.end()); + std::vector kernel_size_(kernel_size.begin(), kernel_size.end()); + std::vector strides_(strides.begin(), strides.end()); + + if (global_pooling) { + for (size_t i = 0; i < kernel_size_.size(); ++i) { + paddings_[i] = 0; + kernel_size_[i] = static_cast(x.dims()[i + 2]); + } + } + + 
funcs::MaxPool2dWithDilationsAndIndexFunctor pool2d_forward; + pool2d_forward( + dev_ctx, x, kernel_size_, strides_, paddings_, dilations_, out, mask); +} + template void Pool2dKernel(const Context& dev_ctx, const DenseTensor& x, @@ -307,6 +405,60 @@ void Pool2dKernel(const Context& dev_ctx, out); } +template +void MaxPool2DWithDilationsKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& kernel_size, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + bool ceil_mode UNUSED, + const std::string& data_format, + bool global_pooling, + const std::string& padding_algorithm, + DenseTensor* out) { + if (x.numel() == 0) { + phi::Full( + dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out); + return; + } + + if (dilations[0] <= static_cast(0) || + dilations[1] <= static_cast(0)) { + PADDLE_THROW(errors::InvalidArgument( + "The dilations of MaxPool2D op must be >= 0, but received dilations = " + "[%ld, %ld].", + dilations[0], + dilations[1])); + } else if (dilations[0] > static_cast(1) || + dilations[1] > static_cast(1)) { + MaxPool2DWithDilationsRawKernel(dev_ctx, + x, + kernel_size.GetData(), + strides, + paddings, + dilations, + data_format, + global_pooling, + padding_algorithm, + out); + } else { + PoolRawKernel(dev_ctx, + x, + kernel_size.GetData(), + strides, + paddings, + true, + data_format, + "max", + false, + false, + padding_algorithm, + 0, + out); + } +} + template void LPPool2dKernel(const Context& dev_ctx, const DenseTensor& x, @@ -358,11 +510,24 @@ void MaxPool2dWithIndexKernel(const Context& dev_ctx, const std::vector& kernel_size, const std::vector& strides, const std::vector& paddings, + const std::vector& dilations, bool global_pooling, bool adaptive, bool ceil_mode UNUSED, DenseTensor* out, DenseTensor* mask) { + if (dilations[0] > 1 || dilations[1] > 1) { + MaxPool2dWithDilationsAndIndexRawKernel(dev_ctx, + x, + kernel_size, + strides, + paddings, + dilations, + global_pooling, + 
out, + mask); + return; + } MaxPoolWithIndexRawKernel(dev_ctx, x, kernel_size, diff --git a/paddle/phi/kernels/pool_grad_kernel.h b/paddle/phi/kernels/pool_grad_kernel.h index efcb561980d54c..de919964d62aa8 100644 --- a/paddle/phi/kernels/pool_grad_kernel.h +++ b/paddle/phi/kernels/pool_grad_kernel.h @@ -39,6 +39,21 @@ void Pool2dGradKernel(const Context& dev_ctx, const std::string& padding_algorithm, DenseTensor* dx); +template +void MaxPool2DWithDilationsGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out, + const DenseTensor& dout, + const IntArray& kernel_size, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + bool ceil_mode, + const std::string& data_format, + bool global_pooling, + const std::string& padding_algorithm, + DenseTensor* dx); + template void LPPool2dGradKernel(const Context& dev_ctx, const DenseTensor& x, @@ -112,6 +127,7 @@ void MaxPool2dWithIndexGradKernel(const Context& dev_ctx, const std::vector& kernel_size, const std::vector& strides, const std::vector& paddings, + const std::vector& dilations, bool global_pooling, bool adaptive, bool ceil_mode, diff --git a/paddle/phi/kernels/pool_kernel.h b/paddle/phi/kernels/pool_kernel.h index b7f929f4d0cb1d..8a98902632ed04 100644 --- a/paddle/phi/kernels/pool_kernel.h +++ b/paddle/phi/kernels/pool_kernel.h @@ -37,6 +37,19 @@ void Pool2dKernel(const Context& dev_ctx, const std::string& padding_algorithm, DenseTensor* out); +template +void MaxPool2DWithDilationsKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& kernel_size, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + bool ceil_mode, + const std::string& data_format, + bool global_pooling, + const std::string& padding_algorithm, + DenseTensor* out); + template void LPPool2dKernel(const Context& dev_ctx, const DenseTensor& x, @@ -74,6 +87,7 @@ void MaxPool2dWithIndexKernel(const Context& dev_ctx, const 
std::vector& kernel_size, const std::vector& strides, const std::vector& paddings, + const std::vector& dilations, bool global_pooling, bool adaptive, bool ceil_mode, diff --git a/paddle/phi/kernels/xpu/pool_grad_kernel.cc b/paddle/phi/kernels/xpu/pool_grad_kernel.cc index dde1f7e8869918..92b93f161b8ca0 100644 --- a/paddle/phi/kernels/xpu/pool_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/pool_grad_kernel.cc @@ -393,6 +393,7 @@ void MaxPool2dWithIndexGradKernel(const Context& dev_ctx, const std::vector& kernel_size_t, const std::vector& strides_t, const std::vector& paddings_t, + const std::vector& dilations_t, bool global_pooling, bool adaptive, bool ceil_mode UNUSED, diff --git a/paddle/phi/kernels/xpu/pool_kernel.cc b/paddle/phi/kernels/xpu/pool_kernel.cc index 4d9fcd461bd892..803638117e0294 100644 --- a/paddle/phi/kernels/xpu/pool_kernel.cc +++ b/paddle/phi/kernels/xpu/pool_kernel.cc @@ -319,6 +319,7 @@ void MaxPool2dWithIndexKernel(const Context& dev_ctx, const std::vector& kernel_size_t, const std::vector& strides_t, const std::vector& paddings_t, + const std::vector& dilations_t, bool global_pooling, bool adaptive, bool ceil_mode UNUSED, diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml index bf0c09ca9da7e7..c60ce9121d64db 100644 --- a/paddle/phi/ops/yaml/backward.yaml +++ b/paddle/phi/ops/yaml/backward.yaml @@ -2314,12 +2314,24 @@ func : max_grad composite : max_grad(x, out, out_grad, axis, keepdim, reduce_all, x_grad) +- backward_op : max_pool2d_with_dilations_grad + forward : max_pool2d_with_dilations(Tensor x, IntArray kernel_size, int64_t[] strides, int64_t[] paddings, int64_t[] dilations, bool ceil_mode, str data_format, bool global_pooling, str padding_algorithm) -> Tensor(out) + args : (Tensor x, Tensor out, Tensor out_grad, IntArray kernel_size, int64_t[] strides, int64_t[] paddings, int64_t[] dilations, bool ceil_mode, str data_format, bool global_pooling, str padding_algorithm) + output : Tensor(x_grad) + infer_meta 
: + func : UnchangedInferMeta + param: [x] + kernel : + func : max_pool2d_with_dilations_grad + param : [x, out, out_grad, kernel_size, strides, paddings, dilations, ceil_mode, data_format, global_pooling, padding_algorithm] + interfaces : paddle::dialect::InferSymbolicShapeInterface + - backward_op : max_pool2d_with_index_grad - forward : max_pool2d_with_index(Tensor x, int[] kernel_size, int[] strides = {1, 1}, int[] paddings = {0, 0}, bool global_pooling = false, bool adaptive = false, bool ceil_mode = false) -> Tensor(out), Tensor(mask) - args : (Tensor x, Tensor mask, Tensor out_grad, int[] kernel_size, int[] strides, int[] paddings, bool global_pooling, bool adaptive, bool ceil_mode = false) + forward : max_pool2d_with_index(Tensor x, int[] kernel_size, int[] strides = {1, 1}, int[] paddings = {0, 0}, int[] dilations = {1, 1}, bool global_pooling = false, bool adaptive = false, bool ceil_mode = false) -> Tensor(out), Tensor(mask) + args : (Tensor x, Tensor mask, Tensor out_grad, int[] kernel_size, int[] strides, int[] paddings, int[] dilations, bool global_pooling, bool adaptive, bool ceil_mode = false) output : Tensor(x_grad) infer_meta : - func : MaxPoolWithIndexGradInferMeta + func : MaxPool2dWithIndexGradInferMeta kernel : func : max_pool2d_with_index_grad diff --git a/paddle/phi/ops/yaml/op_compat.yaml b/paddle/phi/ops/yaml/op_compat.yaml index fa4999de08ee90..c094535ddd757d 100755 --- a/paddle/phi/ops/yaml/op_compat.yaml +++ b/paddle/phi/ops/yaml/op_compat.yaml @@ -2544,6 +2544,26 @@ max_grad : GetReduceGradExpectedKernelType manual_signature : [max] +- op : max_pool2d_with_dilations + backward : max_pool2d_with_dilations_grad + inputs : + {x : X} + outputs : + {out : Out} + attrs : + {kernel_size : ksize} + int_array: + kernel_size : + data_type : int + support_tensor : true + get_expected_kernel_type : + max_pool2d_with_dilations : GetPoolExpectedKernelType + max_pool2d_with_dilations_grad : GetPoolExpectedKernelType + extra : + attrs : [bool 
use_mkldnn = false, bool use_onednn = false, bool use_quantizer = false, + str mkldnn_data_type = "float32", str onednn_data_type = "", bool is_test = false, + bool adaptive = false] + - op : max_pool2d_with_index inputs : {x : X} diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 4095e94feaf356..b946bdd1caed8f 100644 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -3587,11 +3587,23 @@ backward : max_grad interfaces : paddle::dialect::InferSymbolicShapeInterface, paddle::dialect::LayoutTransformationInterface +- op : max_pool2d_with_dilations + args : (Tensor x, IntArray kernel_size, int64_t[] strides, int64_t[] paddings, int64_t[] dilations, bool ceil_mode, str data_format, bool global_pooling, str padding_algorithm) + output : Tensor(out) + infer_meta : + func : MaxPool2DWithDilationsInferMeta + param : [x, kernel_size, strides, paddings, dilations, ceil_mode, data_format, global_pooling, padding_algorithm] + kernel : + func : max_pool2d_with_dilations + param : [x, kernel_size, strides, paddings, dilations, ceil_mode, data_format, global_pooling, padding_algorithm] + backward : max_pool2d_with_dilations_grad + interfaces : paddle::dialect::LayoutTransformationInterface, paddle::dialect::InferSymbolicShapeInterface + - op : max_pool2d_with_index - args : (Tensor x, int[] kernel_size, int[] strides= {1, 1}, int[] paddings = {0, 0}, bool global_pooling = false, bool adaptive = false, bool ceil_mode = false) + args : (Tensor x, int[] kernel_size, int[] strides= {1, 1}, int[] paddings = {0, 0}, int[] dilations = {1, 1}, bool global_pooling = false, bool adaptive = false, bool ceil_mode = false) output : Tensor(out), Tensor(mask) infer_meta : - func : MaxPoolWithIndexInferMeta + func : MaxPool2dWithIndexInferMeta kernel : func : max_pool2d_with_index backward : max_pool2d_with_index_grad diff --git a/python/paddle/nn/functional/pooling.py b/python/paddle/nn/functional/pooling.py index 1988a424ac7ee0..17cbdd066d4cc0 
100755 --- a/python/paddle/nn/functional/pooling.py +++ b/python/paddle/nn/functional/pooling.py @@ -642,11 +642,19 @@ def max_pool1d( # use 2d to implement 1d should expand padding in advance. padding = _expand_low_nd_padding(padding) + dilation = convert_to_list(1, 2, 'pool_dilation') if in_dynamic_or_pir_mode(): if return_mask: pool_out = _C_ops.max_pool2d_with_index( - x, kernel_size, stride, padding, False, False, ceil_mode + x, + kernel_size, + stride, + padding, + dilation, + False, + False, + ceil_mode, ) return ( (squeeze(pool_out[0], [2]), squeeze(pool_out[1], [2])) @@ -678,23 +686,43 @@ def max_pool1d( mask = helper.create_variable_for_type_inference('int32') outputs = {"Out": pool_out, "Mask": mask} - helper.append_op( - type=op_type, - inputs={"X": x}, - outputs=outputs, - attrs={ - "pooling_type": 'max', - "ksize": kernel_size, - "global_pooling": False, - "strides": stride, - "paddings": padding, - "padding_algorithm": padding_algorithm, - "use_cudnn": True, - "ceil_mode": ceil_mode, - "exclusive": True, - "data_format": data_format, - }, - ) + if return_mask: + helper.append_op( + type=op_type, + inputs={"X": x}, + outputs=outputs, + attrs={ + "pooling_type": 'max', + "ksize": kernel_size, + "global_pooling": False, + "strides": stride, + "paddings": padding, + "dilations": dilation, + "padding_algorithm": padding_algorithm, + "use_cudnn": True, + "ceil_mode": ceil_mode, + "exclusive": True, + "data_format": data_format, + }, + ) + else: + helper.append_op( + type=op_type, + inputs={"X": x}, + outputs=outputs, + attrs={ + "pooling_type": 'max', + "ksize": kernel_size, + "global_pooling": False, + "strides": stride, + "paddings": padding, + "padding_algorithm": padding_algorithm, + "use_cudnn": True, + "ceil_mode": ceil_mode, + "exclusive": True, + "data_format": data_format, + }, + ) return ( (squeeze(pool_out, [2]), squeeze(mask, [2])) @@ -1148,6 +1176,7 @@ def max_pool2d( kernel_size: Size2, stride: Size2 | None = None, padding: _PaddingSizeMode | 
Size2 | Size4 = 0, + dilation: Size2 = 1, return_mask: bool = False, ceil_mode: bool = False, data_format: DataLayout2D = 'NCHW', @@ -1176,6 +1205,9 @@ def max_pool2d( 4. A list[int] or tuple(int) whose length is 4. [pad_height_top, pad_height_bottom, pad_width_left, pad_width_right] whose value means the padding size of each side. 5. A list or tuple of pairs of integers. It has the form [[pad_before, pad_after], [pad_before, pad_after], ...]. Note that, the batch dimension and channel dimension should be [0,0] or (0,0). The default value is 0. + dilation(int|list|tuple): The dilation size. Dilation could be in one of the following forms. + 1. An int, which specifies the same dilation size for both the height and width dimensions. + 2. A list[int] or tuple(int) whose length is 2, [dilation_height, dilation_width] whose value means the dilation size of each dimension. The default value is 1. ceil_mode (bool): when True, will use `ceil` instead of `floor` to compute the output shape return_mask (bool): Whether to return the max indices along with the outputs. Default False, only support `"NCHW"` data format data_format (string): The data format of the input and output data. An optional string from: `"NCHW"`, `"NHWC"`. 
@@ -1195,11 +1227,11 @@ def max_pool2d( >>> # max pool2d >>> x = paddle.uniform([1, 3, 32, 32], paddle.float32) - >>> out = F.max_pool2d(x, kernel_size=2, stride=2, padding=0) + >>> out = F.max_pool2d(x, kernel_size=2, stride=2, padding=0, dilation=1) >>> print(out.shape) paddle.Size([1, 3, 16, 16]) >>> # for return_mask=True - >>> out, max_indices = F.max_pool2d(x, kernel_size=2, stride=2, padding=0, return_mask=True) + >>> out, max_indices = F.max_pool2d(x, kernel_size=2, stride=2, padding=0, dilation=1, return_mask=True) >>> print(out.shape) paddle.Size([1, 3, 16, 16]) >>> print(max_indices.shape) @@ -1229,29 +1261,40 @@ def max_pool2d( "When setting return_mask to true, data_format must be set to NCHW in API:max_pool2d" ) + dilation = convert_to_list(dilation, 2, 'pool_dilation') + if in_dynamic_or_pir_mode(): if return_mask: output = _C_ops.max_pool2d_with_index( - x, kernel_size, stride, padding, False, False, ceil_mode + x, + kernel_size, + stride, + padding, + dilation, + False, + False, + ceil_mode, ) return output if return_mask else output[0] else: - return _C_ops.pool2d( + return _C_ops.max_pool2d_with_dilations( x, kernel_size, stride, padding, + dilation, ceil_mode, - True, data_format, - 'max', - False, False, padding_algorithm, ) else: - op_type = 'max_pool2d_with_index' if return_mask else "pool2d" + if return_mask: + op_type = 'max_pool2d_with_index' + else: + op_type = 'max_pool2d_with_dilations' + helper = LayerHelper(op_type, **locals()) check_variable_and_dtype( x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'max_pool2d' @@ -1262,7 +1305,6 @@ def max_pool2d( if return_mask: mask = helper.create_variable_for_type_inference("int32") outputs = {"Out": pool_out, "Mask": mask} - helper.append_op( type="max_pool2d_with_index", inputs={"X": x}, @@ -1284,24 +1326,23 @@ def max_pool2d( else: outputs = {"Out": pool_out} - helper.append_op( - type="pool2d", + type="max_pool2d_with_dilations", inputs={"X": x}, outputs=outputs, attrs={ - 
"pooling_type": 'max', "ksize": kernel_size, "global_pooling": False, "strides": stride, "paddings": padding, + "dilations": dilation, "padding_algorithm": padding_algorithm, - "use_cudnn": True, "ceil_mode": ceil_mode, - "exclusive": True, "data_format": data_format, + "use_cudnn": True, }, ) + return pool_out @@ -1861,9 +1902,10 @@ def adaptive_max_pool1d( pool_size = [1, *convert_to_list(output_size, 1, "pool_size")] x = unsqueeze(x, [2]) + dilation = convert_to_list(1, 2, 'pool_dilation') if in_dynamic_or_pir_mode(): pool_out = _C_ops.max_pool2d_with_index( - x, pool_size, [1, 1], [0, 0], False, True, False + x, pool_size, [1, 1], [0, 0], dilation, False, True, False ) return ( (squeeze(pool_out[0], [2]), squeeze(pool_out[1], [2])) @@ -1893,6 +1935,7 @@ def adaptive_max_pool1d( attrs={ "pooling_type": 'max', "ksize": pool_size, + "dilations": dilation, "adaptive": True, "ceil_mode": False, }, @@ -1953,6 +1996,7 @@ def adaptive_max_pool2d( _check_input(x, 4) in_h, in_w = x.shape[2:4] + dilation = convert_to_list(1, 2, 'pool_dilation') if isinstance(output_size, int): output_size = convert_to_list(output_size, 2, 'output_size') else: @@ -1963,7 +2007,7 @@ def adaptive_max_pool2d( output_size[1] = in_w if in_dynamic_or_pir_mode(): pool_out = _C_ops.max_pool2d_with_index( - x, output_size, [1, 1], [0, 0], False, True, False + x, output_size, [1, 1], [0, 0], dilation, False, True, False ) return pool_out if return_mask else pool_out[0] else: @@ -1989,6 +2033,7 @@ def adaptive_max_pool2d( attrs={ "pooling_type": 'max', "ksize": output_size, + "dilations": dilation, "adaptive": True, "ceil_mode": False, }, diff --git a/python/paddle/nn/layer/pooling.py b/python/paddle/nn/layer/pooling.py index ae21a05043beaf..92e463bd585d40 100755 --- a/python/paddle/nn/layer/pooling.py +++ b/python/paddle/nn/layer/pooling.py @@ -743,6 +743,9 @@ class MaxPool2D(Layer): 4. A list[int] or tuple(int) whose length is \4. 
[pad_height_top, pad_height_bottom, pad_width_left, pad_width_right] whose value means the padding size of each side. 5. A list or tuple of pairs of integers. It has the form [[pad_before, pad_after], [pad_before, pad_after], ...]. Note that, the batch dimension and channel dimension should be [0,0] or (0,0). The default value is 0. + dilation(int|list|tuple): The dilation size. Dilation could be in one of the following forms. + 1. An int, which specifies the same dilation size for both the height and width dimensions. + 2. A list[int] or tuple(int) whose length is 2, [dilation_height, dilation_width] whose value means the dilation size of each dimension. The default value is 1. ceil_mode(bool, optional): when True, will use `ceil` instead of `floor` to compute the output shape return_mask(bool, optional): Whether to return the max indices along with the outputs. data_format(str, optional): The data format of the input and output data. An optional string from: `"NCHW"`, `"NHWC"`. @@ -768,13 +771,13 @@ class MaxPool2D(Layer): >>> # max pool2d >>> input = paddle.uniform([1, 3, 32, 32], dtype="float32", min=-1, max=1) - >>> MaxPool2D = nn.MaxPool2D(kernel_size=2, stride=2, padding=0) + >>> MaxPool2D = nn.MaxPool2D(kernel_size=2, stride=2, padding=0, dilation=1) >>> output = MaxPool2D(input) >>> print(output.shape) paddle.Size([1, 3, 16, 16]) >>> # for return_mask=True - >>> MaxPool2D = nn.MaxPool2D(kernel_size=2, stride=2, padding=0, return_mask=True) + >>> MaxPool2D = nn.MaxPool2D(kernel_size=2, stride=2, padding=0, dilation=1, return_mask=True) >>> output, max_indices = MaxPool2D(input) >>> print(output.shape) paddle.Size([1, 3, 16, 16]) @@ -785,6 +788,7 @@ class MaxPool2D(Layer): kernel_size: Size2 stride: Size2 | None padding: _PaddingSizeMode | Size2 | Size4 + dilation: Size2 return_mask: bool ceil_mode: bool data_format: DataLayout2D @@ -795,6 +799,7 @@ def __init__( kernel_size: Size2, stride: Size2 | None = None, padding: _PaddingSizeMode | Size2 | Size4 = 0, + dilation: Size2 = 1, 
return_mask: bool = False, ceil_mode: bool = False, data_format: DataLayout2D = 'NCHW', @@ -804,6 +809,7 @@ def __init__( self.ksize = kernel_size self.stride = stride self.padding = padding + self.dilation = dilation self.return_mask = return_mask self.ceil_mode = ceil_mode self.data_format = data_format @@ -815,6 +821,7 @@ def forward(self, x: Tensor) -> Tensor: kernel_size=self.ksize, stride=self.stride, padding=self.padding, + dilation=self.dilation, return_mask=self.return_mask, ceil_mode=self.ceil_mode, data_format=self.data_format, diff --git a/test/dygraph_to_static/test_save_load.py b/test/dygraph_to_static/test_save_load.py index bc5f5a7eee139e..51cbf4cb93c4af 100644 --- a/test/dygraph_to_static/test_save_load.py +++ b/test/dygraph_to_static/test_save_load.py @@ -154,7 +154,7 @@ def test_save_load_prim(self): self.assertIn("pd_op.conv2d", load_op_type_list) self.assertIn("pd_op.batch_norm_", load_op_type_list) self.assertIn("pd_op.relu", load_op_type_list) - self.assertIn("pd_op.pool2d", load_op_type_list) + self.assertIn("pd_op.max_pool2d_with_dilations", load_op_type_list) np.testing.assert_allclose(res.numpy(), new_res.numpy(), rtol=1e-05) @test_ast_only @@ -195,7 +195,7 @@ def test_save_load_prim_with_hook(self): self.assertIn("pd_op.conv2d", load_op_type_list) self.assertIn("pd_op.batch_norm_", load_op_type_list) self.assertIn("pd_op.relu", load_op_type_list) - self.assertIn("pd_op.pool2d", load_op_type_list) + self.assertIn("pd_op.max_pool2d_with_dilations", load_op_type_list) np.testing.assert_allclose(res.numpy(), new_res.numpy(), rtol=1e-05) diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt index e995051736e520..fee93bb1427652 100644 --- a/test/legacy_test/CMakeLists.txt +++ b/test/legacy_test/CMakeLists.txt @@ -50,7 +50,6 @@ list(REMOVE_ITEM TEST_OPS test_transformer_api) list(REMOVE_ITEM TEST_OPS test_conv2d_transpose_op) list(REMOVE_ITEM TEST_OPS test_fractional_max_pool2d_op) list(REMOVE_ITEM TEST_OPS 
test_conv2d_op) -list(REMOVE_ITEM TEST_OPS test_pool_max_op) list(REMOVE_ITEM TEST_OPS test_matmul_v2_op) list(REMOVE_ITEM TEST_OPS test_allgather) diff --git a/test/legacy_test/test_pool2d_api.py b/test/legacy_test/test_pool2d_api.py index ae5d8dc9fb0216..8d4025198b25dd 100644 --- a/test/legacy_test/test_pool2d_api.py +++ b/test/legacy_test/test_pool2d_api.py @@ -19,8 +19,10 @@ from test_pool2d_op import ( avg_pool2D_forward_naive, max_pool2D_forward_naive, + max_pool2d_with_dilations_forward_naive, pool2D_forward_naive, ) +from test_pool_max_op import max_pool2d_with_dilations_and_index_forward_naive import paddle from paddle import base @@ -150,6 +152,40 @@ def check_max_static_results(self, place): ) np.testing.assert_allclose(fetches[0], result_np, rtol=1e-05) + def check_max_with_index_static_results(self, place): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + input = paddle.static.data( + name="input", shape=[2, 3, 32, 32], dtype="float32" + ) + out, mask = max_pool2d( + input, + kernel_size=2, + stride=2, + padding=1, + dilation=2, + return_mask=True, + ) + + input_np = np.random.random([2, 3, 32, 32]).astype("float32") + out_np, mask_np = max_pool2d_with_dilations_and_index_forward_naive( + input_np, + ksize=[2, 2], + strides=[2, 2], + paddings=[1, 1], + dilations=[2, 2], + global_pool=False, + ) + + exe = base.Executor(place) + fetches = exe.run( + feed={"input": input_np}, + fetch_list=[out, mask], + ) + np.testing.assert_allclose(fetches[0], out_np, rtol=1e-05) + np.testing.assert_allclose(fetches[1], mask_np, rtol=1e-05) + def check_max_dygraph_results(self, place): with base.dygraph.guard(place): input_np = np.random.random([2, 3, 32, 32]).astype("float32") @@ -790,6 +826,7 @@ def test_pool2d_static(self): paddle.enable_static() for place in self.places: self.check_max_static_results(place) + self.check_max_with_index_static_results(place) self.check_avg_static_results(place) 
self.check_lp_static_results(place) self.check_lp_float64_static(place) @@ -801,8 +838,59 @@ def test_torch_compatible(self): paddle.enable_static() for place in self.places: self.check_max_static_results(place) + self.check_max_with_index_static_results(place) paddle.disable_static() + def check_max_dygraph_with_dilations(self, place): + with base.dygraph.guard(place): + input_np = np.random.random([2, 3, 32, 32]).astype("float32") + input = paddle.to_tensor(input_np) + + result = max_pool2d( + input, + kernel_size=3, + stride=1, + padding=1, + dilation=[2, 2], + return_mask=False, + ) + + result_np = max_pool2d_with_dilations_forward_naive( + input_np, + ksize=[3, 3], + strides=[1, 1], + paddings=[1, 1], + dilations=[2, 2], + ) + + np.testing.assert_allclose(result.numpy(), result_np, rtol=1e-05) + + def check_max_dygraph_with_dilations_and_index(self, place): + with base.dygraph.guard(place): + input_np = np.random.random([2, 3, 32, 32]).astype("float32") + input = paddle.to_tensor(input_np) + + out, mask = max_pool2d( + input, + kernel_size=3, + stride=1, + padding=1, + dilation=[2, 2], + return_mask=True, + ) + + out_np, mask_np = max_pool2d_with_dilations_and_index_forward_naive( + input_np, + ksize=[3, 3], + strides=[1, 1], + paddings=[1, 1], + dilations=[2, 2], + global_pool=False, + ) + + np.testing.assert_allclose(out.numpy(), out_np, rtol=1e-05) + np.testing.assert_allclose(mask.numpy(), mask_np, rtol=1e-05) + def test_pool2d(self): for place in self.places: self.check_max_dygraph_results(place) @@ -823,6 +911,8 @@ def test_pool2d(self): self.check_lp_dygraph_results_norm_type_is_negative_inf(place) self.check_lp_dygraph_float64(place) self.check_lp_dygraph_float16(place) + self.check_max_dygraph_with_dilations(place) + self.check_max_dygraph_with_dilations_and_index(place) class TestPool2DError_API(unittest.TestCase): @@ -1084,6 +1174,23 @@ def run_zero_norm_type(): self.assertRaises(ValueError, run_zero_norm_type) + def run_invalid_dilation(): + with 
base.dygraph.guard(): + input_np = np.random.uniform(-1, 1, [2, 3, 32, 32]).astype( + np.float32 + ) + input_pd = paddle.to_tensor(input_np) + dilation = [-1, 1] + res_pd = max_pool2d( + input_pd, + kernel_size=2, + stride=2, + padding=0, + dilation=dilation, + ) + + self.assertRaises(ValueError, run_invalid_dilation) + class TestPool2D_API_ZeroSize(unittest.TestCase): def setUp(self): diff --git a/test/legacy_test/test_pool2d_op.py b/test/legacy_test/test_pool2d_op.py index 61a39f62df54a3..e13892db5190b2 100644 --- a/test/legacy_test/test_pool2d_op.py +++ b/test/legacy_test/test_pool2d_op.py @@ -35,6 +35,114 @@ def adaptive_end_index(index, input_size, output_size): return int(np.ceil((index + 1) * input_size / output_size)) +def normalize_paddings(paddings): + if len(paddings) == 2: + pad_top = pad_bottom = paddings[0] + pad_left = pad_right = paddings[1] + elif len(paddings) == 4: + pad_top, pad_bottom, pad_left, pad_right = paddings + else: + raise ValueError(f"paddings 必须为长度为 2 或 4,但收到 {paddings}") + return pad_top, pad_bottom, pad_left, pad_right + + +def max_pool2d_with_dilations_forward_naive( + x, + ksize, + strides, + paddings, + dilations, + global_pool=0, + ceil_mode=False, + data_format="NCHW", + padding_algorithm="EXPLICIT", +): + d_h, d_w = dilations + pad_top, pad_bottom, pad_left, pad_right = normalize_paddings(paddings) + + if data_format == "NCHW": + N, C, H, W = x.shape + axis_h, axis_w = 2, 3 + else: + N, H, W, C = x.shape + axis_h, axis_w = 1, 2 + + if global_pool == 1: + ksize = [H, W] + pad_top = pad_bottom = pad_left = pad_right = 0 + H_out = W_out = 1 + else: + if padding_algorithm == "SAME": + H_out = (H + strides[0] - 1) // strides[0] + W_out = (W + strides[1] - 1) // strides[1] + + pad_h_total = max((H_out - 1) * strides[0] + ksize[0] - H, 0) + pad_w_total = max((W_out - 1) * strides[1] + ksize[1] - W, 0) + + pad_top = pad_h_total // 2 + pad_bottom = pad_h_total - pad_top + pad_left = pad_w_total // 2 + pad_right = pad_w_total - 
pad_left + + if padding_algorithm == "VALID": + pad_top = pad_bottom = pad_left = pad_right = 0 + + if padding_algorithm == "VALID": + if ceil_mode: + H_out = (H - ksize[0] + strides[0] - 1) // strides[0] + 1 + W_out = (W - ksize[1] + strides[1] - 1) // strides[1] + 1 + else: + H_out = (H - ksize[0]) // strides[0] + 1 + W_out = (W - ksize[1]) // strides[1] + 1 + else: + if ceil_mode: + H_out = ( + H - ksize[0] + pad_top + pad_bottom + strides[0] - 1 + ) // strides[0] + 1 + W_out = ( + W - ksize[1] + pad_left + pad_right + strides[1] - 1 + ) // strides[1] + 1 + else: + H_out = (H - ksize[0] + pad_top + pad_bottom) // strides[0] + 1 + W_out = (W - ksize[1] + pad_left + pad_right) // strides[1] + 1 + + if data_format == "NCHW": + out = np.zeros((N, C, H_out, W_out), dtype=x.dtype) + else: + out = np.zeros((N, H_out, W_out, C), dtype=x.dtype) + + for oh in range(H_out): + base_h = oh * strides[0] - pad_top + for ow in range(W_out): + base_w = ow * strides[1] - pad_left + + vals = [] + for kh in range(ksize[0]): + h = base_h + kh * d_h + if 0 <= h < H: + for kw in range(ksize[1]): + w = base_w + kw * d_w + if 0 <= w < W: + if data_format == "NCHW": + vals.append(x[:, :, h, w]) + else: + vals.append(x[:, h, w, :]) + + if len(vals) == 0: + if data_format == "NCHW": + out[:, :, oh, ow] = -np.inf + else: + out[:, oh, ow, :] = -np.inf + else: + stacked = np.max(np.stack(vals, axis=-1), axis=-1) + if data_format == "NCHW": + out[:, :, oh, ow] = stacked + else: + out[:, oh, ow, :] = stacked + + return out + + def max_pool2D_forward_naive( x, ksize, @@ -373,6 +481,61 @@ def pool2d_wrapper_use_cudnn( ) +def pool2d_wrapper_not_use_cudnn_with_dilations( + X, + ksize=[], + strides=[], + paddings=[], + dilations=[], + ceil_mode=False, + data_format="NCDHW", + global_pooling=False, + padding_algorithm="EXPLICIT", +): + if in_dynamic_mode(): + X = X._use_gpudnn(False) + if data_format == "AnyLayout": + data_format = "NCDHW" + return paddle._C_ops.max_pool2d_with_dilations( + X, + 
def pool2d_wrapper_not_use_cudnn_with_dilations(
    X,
    ksize=[],
    strides=[],
    paddings=[],
    dilations=[],
    ceil_mode=False,
    data_format="NCDHW",
    global_pooling=False,
    padding_algorithm="EXPLICIT",
):
    """Call ``paddle._C_ops.max_pool2d_with_dilations`` for the non-cuDNN case.

    In dynamic mode the input is first replaced by ``X._use_gpudnn(False)``
    (presumably this opts the tensor out of the cuDNN kernel -- confirm).
    A ``data_format`` of ``"AnyLayout"`` is rewritten to ``"NCDHW"``.

    NOTE(review): ``"NCDHW"`` is a 5-D layout name used as the default of a
    2-D pooling wrapper; verify this is intended.

    Returns:
        Whatever the underlying ``_C_ops`` call returns.
    """
    tensor = X._use_gpudnn(False) if in_dynamic_mode() else X
    layout = "NCDHW" if data_format == "AnyLayout" else data_format
    op_args = (
        tensor,
        ksize,
        strides,
        paddings,
        dilations,
        ceil_mode,
        layout,
        global_pooling,
        padding_algorithm,
    )
    return paddle._C_ops.max_pool2d_with_dilations(*op_args)


def pool2d_wrapper_use_cudnn_with_dilations(
    X,
    ksize=[],
    strides=[],
    paddings=[],
    dilations=[],
    ceil_mode=False,
    data_format="NCDHW",
    global_pooling=False,
    padding_algorithm="EXPLICIT",
):
    """Call ``paddle._C_ops.max_pool2d_with_dilations`` with ``X`` untouched.

    Identical to the non-cuDNN wrapper except that the input tensor is
    forwarded as-is. A ``data_format`` of ``"AnyLayout"`` is rewritten to
    ``"NCDHW"``.

    Returns:
        Whatever the underlying ``_C_ops`` call returns.
    """
    layout = "NCDHW" if data_format == "AnyLayout" else data_format
    op_args = (
        X,
        ksize,
        strides,
        paddings,
        dilations,
        ceil_mode,
        layout,
        global_pooling,
        padding_algorithm,
    )
    return paddle._C_ops.max_pool2d_with_dilations(*op_args)
class TestMax_Pool2D_With_Dilations(TestPool2D_Op):
    """OpTest case checking ``max_pool2d_with_dilations`` against NumPy.

    ``setUp`` replaces the parent's set-up entirely: it runs the ``init_*``
    hooks, builds a random input, computes the expected output with
    ``max_pool2d_with_dilations_forward_naive`` and fills in ``inputs``,
    ``outputs``, ``attrs`` and ``python_api``.

    NOTE(review): ``setUp`` never calls ``init_kernel_type()``. Generated
    subclasses that override that hook to pick a dtype or ``use_cudnn`` may
    therefore have no effect on this case -- confirm.
    """

    def setUp(self):
        self.op_type = "max_pool2d_with_dilations"
        self.use_cudnn = False
        self.use_onednn = False

        self.init_data_type()
        self.init_test_case()
        # Assigned before init_paddings() runs, so a subclass hook may
        # overwrite it.
        self.padding_algorithm = "EXPLICIT"
        self.init_paddings()
        self.init_dilations()
        self.init_global_pool()
        self.init_ceil_mode()
        self.init_data_format()
        self.init_shape()
        self.init_pool_type()

        # bfloat16 cases compute the reference in float32 and convert both
        # input and expected output to the uint16 carrier afterwards.
        bf16 = self.is_bfloat16_op()
        sample_dtype = np.float32 if bf16 else self.dtype
        x_np = np.random.random(self.shape).astype(sample_dtype)

        expected = max_pool2d_with_dilations_forward_naive(
            x_np,
            self.ksize,
            self.strides,
            self.paddings,
            self.dilations,
            global_pool=self.global_pool,
            ceil_mode=self.ceil_mode,
            data_format=self.data_format,
            padding_algorithm=self.padding_algorithm,
        )

        if bf16:
            expected = convert_float_to_uint16(expected)
            self.inputs = {'X': convert_float_to_uint16(x_np)}
        else:
            expected = expected.astype(self.dtype)
            self.inputs = {'X': OpTest.np_dtype_to_base_dtype(x_np)}

        self.outputs = {'Out': expected}

        self.attrs = {
            'strides': self.strides,
            'paddings': self.paddings,
            'dilations': self.dilations,
            'ksize': self.ksize,
            'global_pooling': self.global_pool,
            'use_cudnn': self.use_cudnn,
            'use_onednn': self.use_onednn,
            'ceil_mode': self.ceil_mode,
            'data_format': self.data_format,
            "padding_algorithm": self.padding_algorithm,
        }
        self.python_api = (
            pool2d_wrapper_use_cudnn_with_dilations
            if self.use_cudnn
            else pool2d_wrapper_not_use_cudnn_with_dilations
        )

    def init_dilations(self):
        """Default case: a dilation of 2 on both spatial axes."""
        self.dilations = [2, 2]

    def init_test_case(self):
        """Default case: 3x3 window, unit stride."""
        self.ksize = [3, 3]
        self.strides = [1, 1]

    def init_paddings(self):
        """Default case: padding of 1 per spatial axis."""
        self.paddings = [1, 1]

    def init_global_pool(self):
        """Default case: global pooling disabled."""
        self.global_pool = False

    def init_shape(self):
        """Default case: input of shape [2, 3, 7, 7]."""
        self.shape = [2, 3, 7, 7]

    def init_ceil_mode(self):
        """Default case: floor rounding of the output extent."""
        self.ceil_mode = False

    def test_check_grad(self):
        """Run the gradient check for ``X`` -> ``Out``; skipped for float16."""
        # NOTE(review): max_relative_error=1.00 is a very loose tolerance.
        if self.dtype != np.float16:
            self.check_grad(
                {'X'},
                'Out',
                max_relative_error=1.00,
                check_cinn=True,
                check_pir=True,
            )


class TestMax_Pool2D_With_Dilations_Channel_Last(TestMax_Pool2D_With_Dilations):
    """Channels-last variant using a 4-element padding spec."""

    def init_shape(self):
        self.shape = [2, 7, 7, 3]

    def init_data_format(self):
        self.data_format = "NHWC"

    def init_paddings(self):
        self.paddings = [1, 2, 1, 2]
class TestMax_Pool2D_With_Dilations_Global_Pool(TestMax_Pool2D_With_Dilations):
    """Variant of the dilated max-pool case with global pooling enabled."""

    def init_global_pool(self):
        self.global_pool = True


class TestMax_Pool2D_With_Dilations_Empty_Input(TestMax_Pool2D_With_Dilations):
    """Variant whose input has a zero-sized leading dimension."""

    def init_shape(self):
        # Leading dimension is 0; assumes the inherited layout makes that
        # the batch axis -- TODO confirm.
        self.shape = [0, 7, 7, 3]


class TestMax_Pool2D_With_Dilations_One_Dilation(TestMax_Pool2D_With_Dilations):
    """Variant with a dilation of 1 on both axes."""

    def init_dilations(self):
        self.dilations = [1, 1]


def create_test_bf16_class_v2(parent, check_grad=True):
    """Register a uint16 (bfloat16 carrier) output-only subclass of ``parent``.

    The generated class is skipped unless Paddle is built with CUDA or a
    custom device is present. It is stored in this module's globals under
    ``"<parent name>_Bf16Op"``.

    Args:
        parent: test-case class to derive from.
        check_grad: accepted but not used; the generated class always
            replaces ``test_check_grad`` with a no-op.
    """
    device_missing = not (core.is_compiled_with_cuda() or is_custom_device())

    @unittest.skipIf(device_missing, "core is not compiled with CUDA")
    class TestBf16Case(parent):
        def init_kernel_type(self):
            self.use_cuda = True
            self.dtype = np.uint16

        def test_check_output(self):
            if not (core.is_compiled_with_cuda() or is_custom_device()):
                return
            self.check_output_with_place(
                get_device_place(),
                check_dygraph=(not self.use_onednn),
                check_cinn=True,
                check_pir_onednn=self.check_pir_onednn,
            )

        def test_check_grad(self):
            pass

    generated_name = f"{parent.__name__}_Bf16Op"
    TestBf16Case.__name__ = generated_name
    globals()[generated_name] = TestBf16Case


def create_test_cudnn_fp16_class_v2(parent, check_grad=True):
    """Register a float16, ``use_cudnn=True`` output-only subclass of ``parent``.

    The generated class is skipped unless Paddle is built with CUDA or a
    custom device is present, and its output check only runs when the
    device reports float16 support. It is stored in this module's globals
    under ``"<parent name>_CUDNNFp16Op"``.

    Args:
        parent: test-case class to derive from.
        check_grad: accepted but not used; the generated class always
            replaces ``test_check_grad`` with a no-op.
    """
    device_missing = not (core.is_compiled_with_cuda() or is_custom_device())

    @unittest.skipIf(device_missing, "core is not compiled with CUDA")
    class TestCUDNNFp16Case(parent):
        def init_kernel_type(self):
            self.use_cudnn = True
            self.dtype = np.float16

        def test_check_output(self):
            # TODO(wangzhongpu): support onednn op in dygraph mode
            if not (core.is_compiled_with_cuda() or is_custom_device()):
                return
            place = get_device_place()
            if not core.is_float16_supported(place):
                return
            self.check_output_with_place(
                place,
                check_dygraph=(not self.use_onednn),
                check_cinn=True,
                check_pir_onednn=self.check_pir_onednn,
            )

        def test_check_grad(self):
            pass

    generated_name = f"{parent.__name__}_CUDNNFp16Op"
    TestCUDNNFp16Case.__name__ = generated_name
    globals()[generated_name] = TestCUDNNFp16Case
def create_test_fp16_class_v2(parent, check_grad=True):
    """Register a float16, ``use_cudnn=False`` output-only subclass of ``parent``.

    The generated class is skipped unless Paddle is built with CUDA or a
    custom device is present, and its output check only runs when the
    device reports float16 support. It is stored in this module's globals
    under ``"<parent name>_Fp16Op"``.

    Args:
        parent: test-case class to derive from.
        check_grad: accepted but not used; the generated class always
            replaces ``test_check_grad`` with a no-op.
    """
    device_missing = not (core.is_compiled_with_cuda() or is_custom_device())

    @unittest.skipIf(device_missing, "core is not compiled with CUDA")
    class TestFp16Case(parent):
        def init_kernel_type(self):
            self.use_cudnn = False
            self.dtype = np.float16

        def test_check_output(self):
            # TODO(wangzhongpu): support onednn op in dygraph mode
            if not (core.is_compiled_with_cuda() or is_custom_device()):
                return
            place = get_device_place()
            if not core.is_float16_supported(place):
                return
            self.check_output_with_place(
                place,
                check_dygraph=(not self.use_onednn),
                check_cinn=True,
                check_pir_onednn=self.check_pir_onednn,
            )

        def test_check_grad(self):
            pass

    generated_name = f"{parent.__name__}_Fp16Op"
    TestFp16Case.__name__ = generated_name
    globals()[generated_name] = TestFp16Case


def create_test_cpu_class(parent):
    """Register a subclass of ``parent`` whose checks run on ``CPUPlace``.

    The generated class is stored in this module's globals under
    ``"<parent name>_CPU"``. Its gradient check is skipped for float16.

    Args:
        parent: test-case class to derive from.
    """

    class TestMaxPool2dCPU(parent):
        def test_check_output(self):
            self.check_output_with_place(
                paddle.base.CPUPlace(), check_pir=True, check_cinn=True
            )

        def test_check_grad(self):
            if self.dtype != np.float16:
                self.check_grad_with_place(
                    paddle.base.CPUPlace(),
                    {'X'},
                    'Out',
                    max_relative_error=1.00,
                    check_cinn=True,
                    check_pir=True,
                )

    generated_name = f"{parent.__name__}_CPU"
    TestMaxPool2dCPU.__name__ = generated_name
    globals()[generated_name] = TestMaxPool2dCPU


def _register_dilation_pool_variants():
    """Generate every device/dtype/padding variant of the dilation cases.

    Parents are processed one after another, and for each parent the
    factories run in the listed order. The loop lives inside a function so
    that no extra module-level name ends up bound to a test-case class.
    """
    parents = (
        TestMax_Pool2D_With_Dilations,
        TestMax_Pool2D_With_Dilations_Channel_Last,
        TestMax_Pool2D_With_Dilations_Global_Pool,
        TestMax_Pool2D_With_Dilations_Empty_Input,
        TestMax_Pool2D_With_Dilations_One_Dilation,
    )
    factories = (
        create_test_cudnn_class,
        create_test_fp16_class_v2,
        create_test_bf16_class_v2,
        create_test_cudnn_fp16_class_v2,
        create_test_cudnn_use_ceil_class,
        create_test_use_ceil_class,
        create_test_padding_SAME_class,
        create_test_padding_VALID_class,
        create_test_cpu_class,
    )
    for parent in parents:
        for factory in factories:
            factory(parent)


_register_dilation_pool_variants()
def max_pool2d_with_dilations_and_index_forward_naive(
    x, ksize, strides, paddings, dilations, global_pool=False
):
    """NumPy reference for dilated 2-D max pooling that also returns indices.

    Args:
        x: 4-D array indexed as ``[n, c, h, w]``.
        ksize: ``(k_h, k_w)`` number of taps per window axis.
        strides: ``(s_h, s_w)`` window step.
        paddings: ``[pad_h, pad_w]`` or
            ``[pad_top, pad_bottom, pad_left, pad_right]``.
        dilations: ``(d_h, d_w)`` spacing between consecutive taps.
        global_pool: when truthy the window gets ``H x W`` taps and the
            padding is replaced by zeros.

    Returns:
        ``(out, mask)``. ``out`` has ``x.dtype`` and holds the window
        maxima; ``mask`` is int32 and holds ``h * W + w`` of the first tap
        (row-major order) that reached the maximum. A window with no
        in-bounds tap yields ``-inf`` and index ``-1``.
    """
    N, C, H, W = x.shape
    win_h, win_w = ksize
    step_h, step_w = strides
    dil_h, dil_w = dilations

    if global_pool:
        win_h, win_w = H, W
        paddings = [0, 0, 0, 0]

    if len(paddings) == 2:
        # Symmetric spec: reuse each value for both sides of its axis.
        paddings = [paddings[0], paddings[0], paddings[1], paddings[1]]
    pad_top, pad_bottom, pad_left, pad_right = paddings

    # NOTE(review): the output extent is derived from the tap count alone;
    # the dilation does not enter this computation.
    H_out = (H + pad_top + pad_bottom - win_h) // step_h + 1
    W_out = (W + pad_left + pad_right - win_w) // step_w + 1

    result_shape = (N, C, H_out, W_out)
    out = np.zeros(result_shape, dtype=x.dtype)
    mask = np.zeros(result_shape, dtype=np.int32)

    for oh in range(H_out):
        row_origin = oh * step_h - pad_top
        rows = [
            h
            for h in (row_origin + i * dil_h for i in range(win_h))
            if 0 <= h < H
        ]
        for ow in range(W_out):
            col_origin = ow * step_w - pad_left
            cols = [
                w
                for w in (col_origin + j * dil_w for j in range(win_w))
                if 0 <= w < W
            ]
            # In-bounds taps in row-major order; shared by every (n, c).
            taps = [(h, w) for h in rows for w in cols]

            for n in range(N):
                for c in range(C):
                    best_val = -np.inf
                    best_idx = -1
                    for h, w in taps:
                        candidate = x[n, c, h, w]
                        # Strict comparison keeps the earliest tap on ties.
                        if candidate > best_val:
                            best_val = candidate
                            best_idx = h * W + w
                    out[n, c, oh, ow] = best_val
                    mask[n, c, oh, ow] = best_idx

    return out, mask
class TestDilationsCase1(TestCase4):
    """``max_pool2d_with_index`` case with a dilation of 2 on both axes.

    Uses ``max_pool2d_with_dilations_and_index_forward_naive`` as the
    reference implementation.
    """

    def init_test_case(self):
        self.shape = [2, 3, 7, 7]
        self.ksize = [3, 3]
        self.strides = [1, 1]
        self.paddings = [1, 1]
        self.dilations = [2, 2]
        self.op_type = "max_pool2d_with_index"
        self.python_api = max_pool2d_with_index_wrapper
        self.pool_forward_naive = (
            max_pool2d_with_dilations_and_index_forward_naive
        )


class TestDilationsCase2(TestDilationsCase1):
    """Same configuration as case 1 with global pooling switched off."""

    def init_global(self):
        self.global_pool = False


class TestDilationsCase3(TestDilationsCase2):
    """Zero-batch input with stride 2 and no padding."""

    def init_test_case(self):
        # Start from the case-1 configuration (same op, wrapper, reference,
        # window and dilation) and replace only what differs.
        super().init_test_case()
        self.shape = [0, 3, 7, 7]
        self.strides = [2, 2]
        self.paddings = [0, 0]


# ----------------max_pool2d_with_cpu_place----------------
def create_test_cpu_class(parent):
    """Register a subclass of ``parent`` whose checks run on ``CPUPlace``.

    The generated class is stored in this module's globals under
    ``"<parent name>_CPU"``.

    Args:
        parent: test-case class to derive from.
    """

    class TestMaxPool2dCPU(parent):
        def test_check_output(self):
            cpu = paddle.base.CPUPlace()
            self.check_output_with_place(cpu, check_pir=True, check_cinn=True)

        def test_check_grad(self):
            cpu = paddle.base.CPUPlace()
            self.check_grad_with_place(cpu, {'X'}, ['Out'])

    generated_name = f"{parent.__name__}_CPU"
    TestMaxPool2dCPU.__name__ = generated_name
    globals()[generated_name] = TestMaxPool2dCPU
"{}_{}".format(parent.__name__, "CPU") + TestMaxPool2dCPU.__name__ = cls_name + globals()[cls_name] = TestMaxPool2dCPU + + +create_test_cpu_class(TestCase4) +create_test_cpu_class(TestCase5) +create_test_cpu_class(TestCase6) +create_test_cpu_class(TestCase7) +create_test_cpu_class(TestCastAdaptive2d) +create_test_cpu_class(TestDilationsCase1) +create_test_cpu_class(TestDilationsCase2) +create_test_cpu_class(TestDilationsCase3) + + # ----------------max_pool2d_with_index_fp16---------------- def create_test_fp16_class(parent): @unittest.skipIf( @@ -427,6 +610,8 @@ def test_check_grad(self): create_test_fp16_class(TestCase6) create_test_fp16_class(TestCase7) create_test_fp16_class(TestCastAdaptive2d) +create_test_fp16_class(TestDilationsCase1) +create_test_fp16_class(TestDilationsCase2) # ----------------max_pool2d_with_index_bf16---------------- @@ -473,6 +658,8 @@ def test_check_grad(self): create_test_bf16_class(TestCase6) create_test_bf16_class(TestCase7) create_test_bf16_class(TestCastAdaptive2d) +create_test_bf16_class(TestDilationsCase1) +create_test_bf16_class(TestDilationsCase2) def skip_unit_test(): diff --git a/test/white_list/op_accuracy_white_list.py b/test/white_list/op_accuracy_white_list.py index 2fc3cea9ad5fb5..31b6fe2d64b5aa 100644 --- a/test/white_list/op_accuracy_white_list.py +++ b/test/white_list/op_accuracy_white_list.py @@ -44,6 +44,7 @@ 'max_pool2d_v2', 'max_pool2d_with_index', 'max_pool3d_with_index', + 'max_pool2d_with_dilations', 'fractional_max_pool2d', 'fractional_max_pool3d', 'minus', diff --git a/test/white_list/op_threshold_white_list.py b/test/white_list/op_threshold_white_list.py index f7e888a3615bbf..f8ac86b6640fd6 100644 --- a/test/white_list/op_threshold_white_list.py +++ b/test/white_list/op_threshold_white_list.py @@ -28,6 +28,7 @@ 'kldiv_loss', 'lstm', 'max_pool2d_with_index', + 'max_pool2d_with_dilations', 'max_pool3d_with_index', 'fractional_max_pool2d', 'fractional_max_pool3d',