From 2795262e631ab81f5b8fe02232111f84bc7c3f5b Mon Sep 17 00:00:00 2001 From: SigureMo Date: Tue, 16 Dec 2025 21:25:00 +0800 Subject: [PATCH 1/7] [PIR][2/N] Support register Python function on PIR (`NativeMetaTensor` and lower part) --- .../pir/transforms/pd_op_to_kernel_pass.cc | 113 +++++++++++++++++- paddle/fluid/pybind/CMakeLists.txt | 1 + paddle/fluid/pybind/ir_meta_tensor.cc | 85 ------------- paddle/fluid/pybind/ir_tensor.cc | 91 -------------- paddle/fluid/pybind/ir_tensor.h | 22 ---- paddle/fluid/pybind/native_meta_tensor.cc | 109 +++++++++++++++++ ...{ir_meta_tensor.h => native_meta_tensor.h} | 2 +- paddle/fluid/pybind/pybind.cc | 6 +- python/paddle/static/__init__.py | 2 +- python/paddle/static/meta_tensor.py | 43 ------- 10 files changed, 226 insertions(+), 248 deletions(-) delete mode 100644 paddle/fluid/pybind/ir_meta_tensor.cc delete mode 100644 paddle/fluid/pybind/ir_tensor.cc delete mode 100644 paddle/fluid/pybind/ir_tensor.h create mode 100644 paddle/fluid/pybind/native_meta_tensor.cc rename paddle/fluid/pybind/{ir_meta_tensor.h => native_meta_tensor.h} (94%) delete mode 100644 python/paddle/static/meta_tensor.py diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc index de172d07a4cb57..6c8362aa39a055 100644 --- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc @@ -2505,6 +2505,105 @@ void HandleForCustomOp( block->push_back(op); } +void HandleForPythonOp( + pir::IrContext* ctx, + pir::Operation* op_item, + const phi::KernelKey& kernel_key, + const phi::Place place, + const OpYamlInfoParser* op_info_parser, + std::unordered_map* map_op_pair, + std::unordered_map* map_value_pair, + pir::Block* block) { + // Prepare output + std::vector op_output_types; + for (size_t i = 0; i < op_item->num_results(); ++i) { + phi::Place out_place = phi::TransToPhiPlace(kernel_key.backend()); + PushBackOutputTypes(ctx, + op_item, + op_item->result(i).type(), + out_place, + kernel_key, + &op_output_types); + } + + // Prepare input + std::vector vec_inputs; + for (size_t i = 0; i < op_item->num_operands(); ++i) { + auto cur_in = op_item->operand_source(i); + if (!cur_in) { + vec_inputs.emplace_back(); + continue; + } + PADDLE_ENFORCE_EQ( + map_value_pair->count(cur_in), + true, + common::errors::PreconditionNotMet( + "[%d]'s input of [%s] op MUST in map pair", i, op_item->name())); + + auto new_in = map_value_pair->at(cur_in); + auto new_in_type = new_in.type(); + + if (new_in_type.isa()) { + auto in_place = new_in_type.dyn_cast().place(); + // need trans from GPU_PINNED to GPU, refer to PR#41972 + if (phi::AllocationType::GPUPINNED == place.GetType()) { + // build memcopy op + auto out_place = phi::TransToPhiPlace(phi::Backend::GPU); + auto new_in_alloc_type = + new_in_type.dyn_cast(); + auto out_type = + AllocatedDenseTensorType::get(ctx, + out_place, + new_in_alloc_type.dtype(), + new_in_alloc_type.dims(), + new_in_alloc_type.data_layout(), + new_in_alloc_type.lod(), + new_in_alloc_type.offset()); + new_in = AddPlaceTransferOp( + new_in, out_type, in_place, out_place, kernel_key, block); + } + } + + vec_inputs.push_back(new_in); + } + + // Prepare attr + std::unordered_map op_attribute{ + {"op_name", pir::StrAttribute::get(ctx, op_item->name())}, + {"kernel_name", pir::StrAttribute::get(ctx, op_item->name())}, + {"kernel_key", KernelAttribute::get(ctx, kernel_key)}}; + + auto op_attr_map = op_item->attributes(); + for (auto& map_item : op_attr_map) { + 
op_attribute.emplace(map_item.first, map_item.second); + } + if (op_item->HasTrait()) { + op_attribute.emplace("is_inplace", pir::BoolAttribute::get(ctx, true)); + } + op_attribute.emplace("origin_id", + pir::Int64Attribute::get(ctx, op_item->id())); + + VLOG(6) << "Lower custom pyop: " << op_item->name() + << " to : " << PythonFunctionOp::name(); + + pir::OpInfo py_func_op_info = + ctx->GetRegisteredOpInfo(PythonFunctionOp::name()); + + pir::Operation* op = nullptr; + op = pir::Operation::Create( + vec_inputs, op_attribute, op_output_types, py_func_op_info); + op->set_attribute("origin_id", pir::Int64Attribute::get(ctx, op->id())); + + (*map_op_pair)[op_item] = op; + + if (op_item->num_results() > 0) { + for (size_t i = 0; i < op_item->num_results(); ++i) { + (*map_value_pair)[op_item->result(i)] = op->result(i); + } + } + block->push_back(op); +} + void HandleForTensorRTOp( pir::IrContext* ctx, pir::Operation* op_item, @@ -3588,7 +3687,6 @@ void ProcessBlock( auto kernel_name = GetKernelName(op_info_parser.get(), op_item); auto kernel_key = GetKernelKey( op_item, place, kernel_name, *map_value_pair, op_info_parser.get()); - VLOG(6) << "kernel type " << kernel_key; if (paddle::dialect::IsCustomOp(op_item)) { HandleForCustomOp(ctx, @@ -3602,6 +3700,18 @@ void ProcessBlock( continue; } + if (paddle::dialect::IsCustomPyOp(op_item)) { + HandleForPythonOp(ctx, + op_item, + kernel_key, + place, + op_info_parser.get(), + map_op_pair, + map_value_pair, + new_block); + continue; + } + if (paddle::dialect::IsTensorRTOp(op_item)) { HandleForTensorRTOp(ctx, op_item, @@ -3714,6 +3824,7 @@ std::unique_ptr PdOpLowerToKernelPass(pir::Program* prog, ctx->GetOrRegisterDialect(); ctx->GetOrRegisterDialect(); ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); #ifdef PADDLE_WITH_DNNL ctx->GetOrRegisterDialect(); diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 3e99ab2744ef10..d9bd49e164204f 100755 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -115,6 +115,7 @@ set(PYBIND_SRCS graph.cc ir_tensor.cc ir_meta_tensor.cc + native_meta_tensor.cc reader_py.cc protobuf.cc exception.cc diff --git a/paddle/fluid/pybind/ir_meta_tensor.cc b/paddle/fluid/pybind/ir_meta_tensor.cc deleted file mode 100644 index 9af97134a9dc3a..00000000000000 --- a/paddle/fluid/pybind/ir_meta_tensor.cc +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include "paddle/fluid/pir/dialect/operator/ir/ir_meta_tensor.h" -#include "paddle/common/ddim.h" -#include "paddle/fluid/pybind/ir_meta_tensor.h" -#include "paddle/phi/core/tensor_base.h" -#include "pybind11/functional.h" -#include "pybind11/pybind11.h" -#include "pybind11/stl.h" - -namespace paddle::pybind { - -using IrMetaTensor = paddle::dialect::IrMetaTensor; - -void BindIrMetaTensor(py::module* m) { - py::class_(*m, "IrMetaTensor") - .def(py::init([](const phi::TensorBase& tensor, - const bool strided_kernel_used) { - return IrMetaTensor(tensor, strided_kernel_used); - }), - py::arg("tensor"), - py::arg("strided_kernel_used") = false) - .def( - "set_shape", - [](IrMetaTensor& self, const std::vector& dims) { - phi::DDim ddim = phi::make_ddim(dims); - self.set_dims(ddim); - }, - "Set tensor dimensions") - .def( - "set_dtype", - [](IrMetaTensor& self, const std::string& dtype_str) { - self.set_dtype(phi::StringToDataType(dtype_str)); - }, - "Set tensor data type from string") - .def( - "set_dtype", - [](IrMetaTensor& self, const phi::DataType& dtype) { - self.set_dtype(dtype); - }, - "Set tensor data type from DataType object") - .def_property_readonly( - "dtype", - [](const IrMetaTensor& self) -> phi::DataType { - return self.dtype(); - }, - "Get tensor data type") - .def_property_readonly( - "shape", - [](const IrMetaTensor& self) -> std::vector { - const phi::DDim& dims = self.dims(); - return common::vectorize(dims); - }, - "Get tensor shape") - .def("__repr__", [](const IrMetaTensor& self) { - const phi::DDim& dims = self.dims(); - std::ostringstream shape_ss; - shape_ss << "["; - for (int i = 0; i < dims.size(); ++i) { - if (i > 0) { - shape_ss << ", "; - } - shape_ss << dims[i]; - } - shape_ss << "]"; - std::string dtype_str = phi::DataTypeToString(self.dtype()); - return "IrMetaTensor(shape=" + shape_ss.str() + ", dtype=" + dtype_str + - ")"; - }); -} -} // namespace paddle::pybind diff --git a/paddle/fluid/pybind/ir_tensor.cc b/paddle/fluid/pybind/ir_tensor.cc deleted file mode 100644 index 2777e63cd2b598..00000000000000 --- a/paddle/fluid/pybind/ir_tensor.cc +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include "paddle/fluid/pir/dialect/operator/ir/ir_tensor.h" -#include "paddle/common/ddim.h" -#include "paddle/fluid/pybind/ir_tensor.h" -#include "paddle/phi/core/tensor_base.h" -#include "pybind11/functional.h" -#include "pybind11/pybind11.h" -#include "pybind11/stl.h" - -namespace paddle::pybind { - -using IrTensor = paddle::dialect::IrTensor; - -void BindIrTensor(py::module* m) { - // IrTensor inherits from phi::TensorBase, - // pybind11 requires base type registration - py::class_(*m, "TensorBaseHolder"); - py::class_(*m, "IrTensor") - .def(py::init<>()) - .def(py::init()) - .def( - "copy", - [](const IrTensor& self) { return IrTensor(self); }, - "Create a deep copy of this tensor") - .def( - "set_shape", - [](IrTensor& self, const std::vector& dims) { - phi::DDim ddim = phi::make_ddim(dims); - self.SetDims(ddim); - }, - "Set tensor dimensions") - .def( - "set_dtype", - [](IrTensor& self, const std::string& dtype_str) { - self.SetDtype(phi::StringToDataType(dtype_str)); - }, - "Set tensor data type from string") - .def( - "set_dtype", - [](IrTensor& self, const phi::DataType& dtype) { - self.SetDtype(dtype); - }, - "Set tensor data type from DataType object") - .def_property_readonly( - "dtype", - [](const IrTensor& self) -> phi::DataType { return self.dtype(); }, - "Get tensor data type") - .def_property_readonly( - "shape", - [](const IrTensor& self) -> std::vector { - const phi::DDim& dims = self.dims(); - return common::vectorize(dims); - }, - "Get tensor shape") - .def("__eq__", - [](const IrTensor& self, const IrTensor& other) { - return self.dtype() == other.dtype() && - self.dims() == other.dims(); - }) - .def("__repr__", [](const IrTensor& self) { - const phi::DDim& dims = self.dims(); - std::ostringstream shape_ss; - shape_ss << "["; - for (int i = 0; i < dims.size(); ++i) { - if (i > 0) { - shape_ss << ", "; - } - shape_ss << dims[i]; - } - shape_ss << "]"; - std::string dtype_str = phi::DataTypeToString(self.dtype()); - return "IrTensor(shape=" + shape_ss.str() + ", dtype=" + dtype_str + - ")"; - }); -} -} // namespace paddle::pybind diff --git a/paddle/fluid/pybind/ir_tensor.h b/paddle/fluid/pybind/ir_tensor.h deleted file mode 100644 index 7704e2bf6c8f71..00000000000000 --- a/paddle/fluid/pybind/ir_tensor.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -#pragma once - -#include "pybind11/pybind11.h" - -namespace py = pybind11; -namespace paddle { -namespace pybind { - -void BindIrTensor(py::module* m); - -} // namespace pybind -} // namespace paddle diff --git a/paddle/fluid/pybind/native_meta_tensor.cc b/paddle/fluid/pybind/native_meta_tensor.cc new file mode 100644 index 00000000000000..8f891e015d0d2c --- /dev/null +++ b/paddle/fluid/pybind/native_meta_tensor.cc @@ -0,0 +1,109 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/api/ext/native_meta_tensor.h" +#include "paddle/fluid/pybind/native_meta_tensor.h" +#include "paddle/utils/pybind.h" +#include "pybind11/functional.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +namespace paddle::pybind { + +void BindNativeMetaTensor(py::module* m) { + py::class_(*m, "NativeMetaTensor") + .def(py::init<>()) + .def(py::init()) + .def(py::init([](const py::object& dtype, const py::object& shape) { + phi::DataType dt = phi::DataType::FLOAT32; + if (!dtype.is_none()) { + dt = dtype.cast(); + } + std::vector dims; + if (py::isinstance(shape) || + py::isinstance(shape)) { + dims = shape.cast>(); + } else { + PADDLE_THROW(common::errors::InvalidArgument( + "The shape argument must be a list or tuple of integers " + "or None, but got %s.", + py::str(shape))); + } + return phi::NativeMetaTensor(dt, phi::make_ddim(dims)); + }), + py::arg("dtype") = py::none(), + py::arg("shape") = py::list()) + .def( + "copy", + [](const phi::NativeMetaTensor& self) { + return phi::NativeMetaTensor(self); + }, + "Create a deep copy of this tensor") + .def( + "set_shape", + [](phi::NativeMetaTensor& self, const std::vector& dims) { + phi::DDim ddim = phi::make_ddim(dims); + self.set_dims(ddim); + }, + "Set tensor dimensions") + .def( + "set_dtype", + [](phi::NativeMetaTensor& self, const std::string& dtype_str) { + self.set_dtype(phi::StringToDataType(dtype_str)); + }, + "Set tensor data type from string") + .def( + "set_dtype", + [](phi::NativeMetaTensor& self, const phi::DataType& dtype) { + self.set_dtype(dtype); + }, + "Set tensor data type from DataType object") + .def_property_readonly( + "dtype", + [](const phi::NativeMetaTensor& self) -> phi::DataType { + return self.dtype(); + }, + "Get tensor data type") + .def_property_readonly( + "shape", + [](const phi::NativeMetaTensor& self) -> std::vector { + const phi::DDim& dims = self.dims(); + return common::vectorize(dims); + }, + "Get tensor shape") + .def("__eq__", + [](const phi::NativeMetaTensor& self, + const phi::NativeMetaTensor& other) { + return self.dtype() == other.dtype() && + self.dims() == other.dims(); + }) + .def("__repr__", [](const phi::NativeMetaTensor& self) { + const phi::DDim& dims = self.dims(); + std::ostringstream shape_ss; + shape_ss << "["; + for (int i = 0; i < dims.size(); ++i) { + if (i > 0) { + shape_ss << ", "; + } + shape_ss << dims[i]; + } + shape_ss << "]"; + std::string dtype_str = phi::DataTypeToString(self.dtype()); + return "NativeMetaTensor(shape=" + shape_ss.str() + + ", dtype=" + dtype_str + ")"; + }); +} +} // namespace paddle::pybind diff --git a/paddle/fluid/pybind/ir_meta_tensor.h b/paddle/fluid/pybind/native_meta_tensor.h similarity index 94% rename from paddle/fluid/pybind/ir_meta_tensor.h rename to paddle/fluid/pybind/native_meta_tensor.h index f777477311e650..afcd1e7ecffe5e 100644 --- a/paddle/fluid/pybind/ir_meta_tensor.h +++ b/paddle/fluid/pybind/native_meta_tensor.h @@ -16,7 +16,7 
@@ namespace py = pybind11; namespace paddle { namespace pybind { -void BindIrMetaTensor(py::module* m); +void BindNativeMetaTensor(py::module* m); } // namespace pybind } // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 461fdda7b2e715..1bf479a1470269 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -125,10 +125,9 @@ limitations under the License. */ #include "paddle/fluid/pybind/imperative.h" #include "paddle/fluid/pybind/inference_api.h" #include "paddle/fluid/pybind/io.h" -#include "paddle/fluid/pybind/ir_meta_tensor.h" -#include "paddle/fluid/pybind/ir_tensor.h" #include "paddle/fluid/pybind/jit.h" #include "paddle/fluid/pybind/metrics_py.h" +#include "paddle/fluid/pybind/native_meta_tensor.h" #include "paddle/fluid/pybind/pir.h" #include "paddle/fluid/pybind/pybind_variant_caster.h" #include "paddle/fluid/pybind/python_callable_registry.h" @@ -1573,8 +1572,7 @@ PYBIND11_MODULE(libpaddle, m) { BindJit(&m); BindSot(&m); BindCustomDevicePy(&m); - BindIrTensor(&m); - BindIrMetaTensor(&m); + BindNativeMetaTensor(&m); BindEagerUtils(m.ptr()); BindOpFunctionCommon(m.ptr()); diff --git a/python/paddle/static/__init__.py b/python/paddle/static/__init__.py index b77f30cf86bf88..b625b0b80c13a8 100644 --- a/python/paddle/static/__init__.py +++ b/python/paddle/static/__init__.py @@ -38,6 +38,7 @@ set_ipu_shard, xpu_places, ) +from ..base.libpaddle import NativeMetaTensor as MetaTensor # noqa: F401 from ..base.param_attr import WeightNormParamAttr from ..tensor.creation import create_global_var, create_parameter from . import amp, nn # noqa: F401 @@ -64,7 +65,6 @@ serialize_program, set_program_state, ) -from .meta_tensor import MetaTensor # noqa: F401 from .nn.common import ExponentialMovingAverage, py_func from .nn.control_flow import Print from .nn.metric import accuracy, auc, ctr_metric_bundle diff --git a/python/paddle/static/meta_tensor.py b/python/paddle/static/meta_tensor.py deleted file mode 100644 index 1c96015deef308..00000000000000 --- a/python/paddle/static/meta_tensor.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from ..base.libpaddle import IrMetaTensor, IrTensor - - -class MetaTensor: - def __init__(self, shape=[], dtype="float32"): - self.ir_tensor = IrTensor() - self.ir_meta_tensor = IrMetaTensor(self.ir_tensor) - self.ir_meta_tensor.set_shape(shape) - self.ir_meta_tensor.set_dtype(dtype) - - def set_shape(self, shape): - self.ir_meta_tensor.set_shape(shape) - - @property - def shape(self): - return self.ir_meta_tensor.shape - - def set_dtype(self, dtype): - self.ir_meta_tensor.set_dtype(dtype) - - @property - def dtype(self): - return self.ir_meta_tensor.dtype - - def __eq__(self, other): - return ( - self.ir_meta_tensor.dtype == other.ir_meta_tensor.dtype - and self.ir_meta_tensor.shape == other.ir_meta_tensor.shape - ) From 37a339386499c3f14b64d80619ae8897ff87d6aa Mon Sep 17 00:00:00 2001 From: SigureMo Date: Tue, 16 Dec 2025 21:29:59 +0800 Subject: [PATCH 2/7] add co-author Co-authored-by: Ryan <44900829+DrRyanHuang@users.noreply.github.com> From 680d68a5c8fba68299a74c52dfdbcbd6ad7ce973 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Wed, 17 Dec 2025 19:54:02 +0800 Subject: [PATCH 3/7] IsCustomPyOp -> IsPythonOp --- paddle/fluid/pir/dialect/operator/ir/op_dialect.h | 2 +- paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.h b/paddle/fluid/pir/dialect/operator/ir/op_dialect.h index 946b529071914e..1deb3b60f2498d 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.h +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.h @@ -48,7 +48,7 @@ inline bool IsCustomOp(pir::Operation* op) { return op_name.find("custom_op") != op_name.npos; } -inline bool IsCustomPyOp(pir::Operation* op) { +inline bool IsPythonOp(pir::Operation* op) { const std::string& op_name = op->name(); return op_name.find("py_op") != op_name.npos; } diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc index 6c8362aa39a055..decd418a4a0d5a 100644 --- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc @@ -3700,7 +3700,7 @@ void ProcessBlock( continue; } - if (paddle::dialect::IsCustomPyOp(op_item)) { + if (paddle::dialect::IsPythonOp(op_item)) { HandleForPythonOp(ctx, op_item, kernel_key, From 43104dae0027e5751c35ece920e0810f571abcf8 Mon Sep 17 00:00:00 2001 From: Nyakku Shigure Date: Wed, 17 Dec 2025 20:12:14 +0800 Subject: [PATCH 4/7] Update paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc --- paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc index decd418a4a0d5a..b54d6a8910256d 100644 --- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc @@ -2583,7 +2583,7 @@ void HandleForPythonOp( op_attribute.emplace("origin_id", pir::Int64Attribute::get(ctx, op_item->id())); - VLOG(6) << "Lower custom pyop: " << op_item->name() + VLOG(6) << "Lower pyop: " << op_item->name() << " to : " << PythonFunctionOp::name(); pir::OpInfo py_func_op_info = From 2d0d37316ab76c6d07267ca34735a6f1f9c38ae8 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Wed, 17 Dec 2025 20:12:43 +0800 Subject: [PATCH 5/7] rm ir_tensor and ir_meta_tensor.cc --- paddle/fluid/pybind/CMakeLists.txt | 2 -- 1 file changed, 2 deletions(-) diff --git a/paddle/fluid/pybind/CMakeLists.txt 
b/paddle/fluid/pybind/CMakeLists.txt index d9bd49e164204f..27b935cbfc697b 100755 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -113,8 +113,6 @@ set(PYBIND_SRCS pir.cc pir_utils.cc graph.cc - ir_tensor.cc - ir_meta_tensor.cc native_meta_tensor.cc reader_py.cc protobuf.cc From 6db77bb865a7d48db72f8b5f5be7fffec460a687 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Wed, 17 Dec 2025 22:03:39 +0800 Subject: [PATCH 6/7] remove 2 cases --- test/legacy_test/test_meta_tensor.py | 44 ---------------------------- 1 file changed, 44 deletions(-) diff --git a/test/legacy_test/test_meta_tensor.py b/test/legacy_test/test_meta_tensor.py index f873f5cd6d9197..5897a3f37a8693 100644 --- a/test/legacy_test/test_meta_tensor.py +++ b/test/legacy_test/test_meta_tensor.py @@ -15,53 +15,9 @@ import unittest import paddle -from paddle.base.libpaddle import IrMetaTensor, IrTensor from paddle.static import MetaTensor -class TestIrTensor(unittest.TestCase): - def test_basic_get_set(self): - ir_tensor = IrTensor() - - ir_tensor.set_shape([4, 8192, 768]) - self.assertEqual(ir_tensor.shape, [4, 8192, 768]) - - ir_tensor.set_dtype('bfloat16') - self.assertEqual(ir_tensor.dtype, paddle.bfloat16) - ir_tensor.set_dtype(paddle.uint8) - self.assertEqual(ir_tensor.dtype, paddle.uint8) - - def test_eq(self): - x_ir_meta = IrTensor() - y_ir_meta = IrTensor() - self.assertEqual(x_ir_meta, y_ir_meta) - x_ir_meta.set_shape([4, 8192]) - y_ir_meta.set_shape([4, 8192]) - self.assertEqual(x_ir_meta, y_ir_meta) - x_ir_meta.set_shape([4, 8193]) - self.assertNotEqual(x_ir_meta, y_ir_meta) - y_ir_meta = IrTensor(x_ir_meta) - self.assertEqual(x_ir_meta, y_ir_meta) - - -class TestIrMetaTensor(unittest.TestCase): - def test_basic_get_set(self): - ir_tensor = IrTensor() - ir_meta_tensor = IrMetaTensor(ir_tensor) - - shape = [4, 8192, 768] - ir_meta_tensor.set_shape(shape) - self.assertEqual(ir_tensor.shape, shape) - self.assertEqual(ir_meta_tensor.shape, shape) - - ir_meta_tensor.set_dtype('bfloat16') - self.assertEqual(ir_tensor.dtype, paddle.bfloat16) - self.assertEqual(ir_meta_tensor.dtype, paddle.bfloat16) - ir_meta_tensor.set_dtype(paddle.uint8) - self.assertEqual(ir_tensor.dtype, paddle.uint8) - self.assertEqual(ir_meta_tensor.dtype, paddle.uint8) - - def infer_meta_fn(x_meta: MetaTensor, y_meta: MetaTensor): z_meta = MetaTensor() z_meta.set_shape([x_meta.shape[0], y_meta.shape[-1]]) From cdbfe73dead35fde5b72711de3e8219adac52152 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Wed, 17 Dec 2025 22:04:42 +0800 Subject: [PATCH 7/7] reorder proxy docs --- python/paddle/compat/proxy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/compat/proxy.py b/python/paddle/compat/proxy.py index e17d33b38598a3..d3b761a281530e 100644 --- a/python/paddle/compat/proxy.py +++ b/python/paddle/compat/proxy.py @@ -454,10 +454,10 @@ def enable_torch_proxy( PyTorch compat for. If None, enables PyTorch compat globally. Defaults to None. blocked_modules (str or Iterable[str], optional): Specific module or modules to exclude from PyTorch compat. Defaults to None. - silent (bool, optional): If True, suppresses warnings about scope changes. - Defaults to False. backend (str, optional): The backend to enable compat for. Currently only "torch" is supported. Defaults to "torch". + silent (bool, optional): If True, suppresses warnings about scope changes. + Defaults to False. Example: .. code-block:: pycon
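
For reference, a minimal usage sketch of the `MetaTensor` binding this series ends up exposing (`paddle.static.MetaTensor` is the `NativeMetaTensor` alias added in patch 1; the constructor, `set_shape`/`set_dtype`, `shape`/`dtype`, `copy` and `__eq__` are the bindings from `native_meta_tensor.cc` above, and the flow mirrors `infer_meta_fn` in `test/legacy_test/test_meta_tensor.py`). Illustrative only, not part of any patch:

    import paddle
    from paddle.static import MetaTensor  # alias of libpaddle.NativeMetaTensor

    # Describe the inputs of a Python-registered op by shape/dtype only.
    x_meta = MetaTensor()
    x_meta.set_shape([4, 8192])
    x_meta.set_dtype("bfloat16")                        # string overload
    y_meta = MetaTensor(paddle.bfloat16, [8192, 768])   # dtype/shape constructor

    # An infer_meta function derives the output meta from the input metas.
    z_meta = MetaTensor()
    z_meta.set_shape([x_meta.shape[0], y_meta.shape[-1]])
    z_meta.set_dtype(x_meta.dtype)

    assert z_meta.shape == [4, 768]
    assert z_meta == z_meta.copy()  # __eq__ compares dtype and dims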