Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
70a3ac0
add custom pyop IR
DrRyanHuang Dec 12, 2025
0a0c7d5
rename custom_pyop -> python_operator
DrRyanHuang Dec 15, 2025
4c7d96d
add all
DrRyanHuang Dec 15, 2025
e14ac1f
recover
DrRyanHuang Dec 15, 2025
a3101ed
rename some file && function
DrRyanHuang Dec 15, 2025
4ca7657
use NativeMetaTensor instead of IrTensor
SigureMo Dec 15, 2025
d926867
export `paddle/phi/core/ddim.h`
SigureMo Dec 16, 2025
b723138
rm MetaTensor
SigureMo Dec 16, 2025
88893b5
`PythonOperationFunctionInstruction` -> `PythonFunctionInstruction`
SigureMo Dec 16, 2025
2bfd3b7
2023 -> 2025
DrRyanHuang Dec 16, 2025
b8320a8
refine name in instructions
SigureMo Dec 16, 2025
a4432ed
custom py func -> python function
SigureMo Dec 16, 2025
5c7db9d
use void* and PointerAttribute
DrRyanHuang Dec 16, 2025
7287248
`HandleForPythonOp`
SigureMo Dec 16, 2025
cdb9201
refine comment
SigureMo Dec 16, 2025
a3ce468
Merge branch 'develop' into add_custom_pyop_IR
SigureMo Dec 16, 2025
a7ac4f3
normalize eager_utils
SigureMo Dec 17, 2025
18c9c2b
align part1 and part2
SigureMo Dec 17, 2025
ef81be2
Merge branch 'develop' into add_custom_pyop_IR
SigureMo Dec 17, 2025
395ea10
rm operator<< for vec<T>
SigureMo Dec 18, 2025
fd77e86
cleanup code
SigureMo Dec 18, 2025
2145b0d
cleanup comments
SigureMo Dec 18, 2025
b786b82
prepend indent to op attrs
SigureMo Dec 18, 2025
b9d3738
restore exc in pir_interpreter
SigureMo Dec 18, 2025
6819b90
cleanup some comments
SigureMo Dec 18, 2025
0a82cee
cleanup some code
SigureMo Dec 18, 2025
284deae
reformat msg
SigureMo Dec 18, 2025
c797f4d
cleanup infermeta in runtime
SigureMo Dec 18, 2025
b420902
normalize some python code
SigureMo Dec 18, 2025
6e0ee44
add ut
SigureMo Dec 18, 2025
05b7ca0
revert python/paddle/jit/sot/opcode_translator/executor/function_grap…
SigureMo Dec 18, 2025
744458e
Apply suggestions from code review
SigureMo Dec 18, 2025
bffe338
Apply suggestions from code review
SigureMo Dec 18, 2025
9fc149a
format code
SigureMo Dec 18, 2025
858f7b5
fix co_flags shadowing
SigureMo Dec 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
rename custom_pyop -> python_operator
  • Loading branch information
DrRyanHuang committed Dec 15, 2025
commit 0a0c7d55ab53f900aa5f0f50a38e37b8621f6b51
2 changes: 1 addition & 1 deletion paddle/fluid/framework/CMakeLists.txt
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

统一一下命名吧:

  • dialect name 为 py_op
  • 在文件名等一些非缩写场景,使用 python_operator,两者一起不缩写

避免和 custom op 相互之间冲突

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lower 后的 kernel dialect 叫啥合适呢?之前叫 custom_py_func,改成 python_operator_function ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

自定义算子的叫啥?

Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ configure_file(commit.h.in commit.h)

cc_library(
custom_operator
SRCS custom_operator.cc custom_pyoperator.cc
SRCS custom_operator.cc python_operator.cc
DEPS tensor
attribute
op_registry
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/framework/custom_operator_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ limitations under the License. */
namespace paddle {
namespace framework {
constexpr char kCustomDialectPrefix[] = "custom_op."; // NOLINT
constexpr char kCustomPyDialectPrefix[] = "custom_pyop."; // NOLINT
constexpr char kPythonOperatorDialectPrefix[] = "py_op."; // NOLINT
constexpr char kGradSuffix[] = "_grad"; // NOLINT
constexpr char kDoubleGradSuffix[] = "_grad_grad"; // NOLINT

Expand Down Expand Up @@ -158,9 +158,9 @@ inline static const OpMetaInfo& GetOpInfoByPirName(
}
}

inline static const OpMetaInfo& GetCustomPyOpInfoByPirName(
inline static const OpMetaInfo& GetPythonOperatorInfoByPirName(
const std::string& pir_op_name) {
auto custom_name = pir_op_name.substr(strlen(kCustomPyDialectPrefix));
auto custom_name = pir_op_name.substr(strlen(kPythonOperatorDialectPrefix));
int pos = custom_name.length();

if (custom_name[pos - 1] == '_') {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,28 @@
#include <utility> // for std::move

#include "paddle/fluid/framework/custom_operator_utils.h"
#include "paddle/fluid/framework/custom_pyoperator.h"
#include "paddle/fluid/framework/python_operator.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/phi/api/ext/op_meta_info.h"

namespace paddle::framework {

void RegisterCustomPyOp(
void RegisterPythonOperator(
const std::string& op_name,
std::vector<std::string>&& op_inputs,
std::vector<std::string>&& op_outputs,
std::vector<std::string>&& op_attrs,
std::unordered_map<std::string, std::string>&& op_inplace_map,
WrapPythonFunction&& pyop_func,
WrapInferMetaPythonFunction&& pyop_func_infer_meta) {
PythonOperatorFunctionType&& pyop_func,
PythonOperatorInferMetaFunctionType&& pyop_func_infer_meta) {
::paddle::OpMetaInfoBuilder op_meta_info_builder =
::paddle::OpMetaInfoBuilder(std::string(op_name), 0);
op_meta_info_builder.Inputs(std::move(op_inputs))
.Outputs(std::move(op_outputs))
.Attrs(std::move(op_attrs))
.SetInplaceMap(std::move(op_inplace_map))
.SetPyCustomPyOpFunction(pyop_func)
.SetPyCustomPyOpInferMetaFunction(pyop_func_infer_meta);
.SetPythonOperatorFunction(pyop_func)
.SetPythonOperatorInferMetaFunction(pyop_func_infer_meta);

const std::vector<paddle::OpMetaInfo>& op_meta_info_vector =
OpMetaInfoMap::Instance()[op_name];
Expand All @@ -50,17 +50,17 @@ void RegisterCustomPyOp(
const auto& op_meta_info = op_meta_info_vector.back();

auto& inplace_map = OpMetaInfoHelper::GetInplaceMap(op_meta_info);
const auto postfix = inplace_map.empty() ? "" : "_";
const auto suffix = inplace_map.empty() ? "" : "_";

::pir::IrContext* ctx = ::pir::IrContext::Instance();
auto* custom_pyop_dialect =
ctx->GetOrRegisterDialect<paddle::dialect::CustomPyOpDialect>();
auto* python_operator_dialect =
ctx->GetOrRegisterDialect<paddle::dialect::PythonOperatorDialect>();

if (custom_pyop_dialect->HasRegistered(
paddle::framework::kCustomPyDialectPrefix + op_name + postfix)) {
if (python_operator_dialect->HasRegistered(
paddle::framework::kPythonOperatorDialectPrefix + op_name + suffix)) {
return;
}
custom_pyop_dialect->RegisterCustomPyOp(op_meta_info);
python_operator_dialect->RegisterPythonOperator(op_meta_info);
}

} // namespace paddle::framework
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@
namespace paddle {
namespace framework {

void RegisterCustomPyOp(
void RegisterPythonOperator(
const std::string& op_name,
std::vector<std::string>&& op_inputs,
std::vector<std::string>&& op_outputs,
std::vector<std::string>&& op_attrs,
std::unordered_map<std::string, std::string>&& op_inplace_map,
WrapPythonFunction&& func,
WrapInferMetaPythonFunction&& infer_meta);
PythonOperatorFunctionType&& func,
PythonOperatorInferMetaFunctionType&& infer_meta);

} // namespace framework
} // namespace paddle
37 changes: 20 additions & 17 deletions paddle/fluid/pir/dialect/operator/ir/op_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -625,9 +625,9 @@ struct CustomOpInfoInterfaceModel : public OpYamlInfoInterface::Concept {
struct CustomPyOpInfoInterfaceModel : public OpYamlInfoInterface::Concept {
static OpInfoTuple GetPirOpInfo(const std::string& pir_op_name) {
const auto& op_meta =
paddle::framework::detail::GetCustomPyOpInfoByPirName(pir_op_name);
paddle::framework::detail::GetPythonOperatorInfoByPirName(pir_op_name);

// TODO(DrRyanHuang): we may support custom_pyop's grad op in the future
// TODO(DrRyanHuang): we may support py_op's grad op in the future
// const auto* grad_op_meta_ptr =
// paddle::framework::detail::GetGradOpInfoByFwdPirName(pir_op_name);
std::vector<paddle::dialect::OpInputInfo> inputs_info;
Expand All @@ -653,14 +653,15 @@ struct CustomPyOpInfoInterfaceModel : public OpYamlInfoInterface::Concept {
auto& op_attrs = OpMetaInfoHelper::GetAttrs(op_meta);
for (const auto& op_attr : op_attrs) {
auto attr_name_and_type = paddle::ParseAttrStr(op_attr);
// CustomPyOp only has int64_t attr
// PythonOperator only has int64_t attr
const std::string& attr_name = attr_name_and_type[0];
const std::string& attr_type_str = attr_name_and_type[1];
PADDLE_ENFORCE_EQ(attr_type_str,
"int64_t",
common::errors::InvalidArgument(
"CustomPyOp only has two int64_t attributes, which "
"are infer_meta_fn_ptr & fn_ptr."));
PADDLE_ENFORCE_EQ(
attr_type_str,
"int64_t",
common::errors::InvalidArgument(
"PythonOperator only has two int64_t attributes, which "
"are infer_meta_fn_ptr & fn_ptr."));
param_names.push_back(attr_name);
const std::string& attr_pir_type =
CppTypeToAttrTypeMap().at(attr_type_str);
Expand Down Expand Up @@ -1221,26 +1222,28 @@ void CustomOpDialect::RegisterCustomOp(const paddle::OpMetaInfo& op_meta) {
verify_func);
}

CustomPyOpDialect::CustomPyOpDialect(pir::IrContext* context)
: pir::Dialect(name(), context, pir::TypeId::get<CustomPyOpDialect>()) {}
PythonOperatorDialect::PythonOperatorDialect(pir::IrContext* context)
: pir::Dialect(name(), context, pir::TypeId::get<PythonOperatorDialect>()) {
}

void CustomPyOpDialect::PrintType(pir::Type type, std::ostream& os) const {
void PythonOperatorDialect::PrintType(pir::Type type, std::ostream& os) const {
PrintTypeImpl(type, os);
}

void CustomPyOpDialect::PrintAttribute(pir::Attribute attr,
std::ostream& os) const {
void PythonOperatorDialect::PrintAttribute(pir::Attribute attr,
std::ostream& os) const {
PrintAttributeImpl(attr, os);
}

pir::OpPrintFn CustomPyOpDialect::PrintOperation(
pir::OpPrintFn PythonOperatorDialect::PrintOperation(
const pir::Operation& op) const {
return nullptr;
}

void CustomPyOpDialect::RegisterCustomPyOp(const paddle::OpMetaInfo& op_meta) {
void PythonOperatorDialect::RegisterPythonOperator(
const paddle::OpMetaInfo& op_meta) {
pir::TypeId id = IdManager::Instance().CreateId();
std::string op_name = paddle::framework::kCustomPyDialectPrefix +
std::string op_name = paddle::framework::kPythonOperatorDialectPrefix +
OpMetaInfoHelper::GetOpName(op_meta);
std::vector<pir::TypeId> traits;

Expand Down Expand Up @@ -1323,5 +1326,5 @@ pir::OpPrintFn CustomEngineDialect::PrintOperation(

IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OperatorDialect)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::CustomOpDialect)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::CustomPyOpDialect)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::PythonOperatorDialect)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::CustomEngineDialect)
12 changes: 6 additions & 6 deletions paddle/fluid/pir/dialect/operator/ir/op_dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ inline bool IsCustomOp(pir::Operation* op) {

inline bool IsCustomPyOp(pir::Operation* op) {
const std::string& op_name = op->name();
return op_name.find("custom_pyop") != op_name.npos;
return op_name.find("py_op") != op_name.npos;
}

inline bool IsCustomEngineOp(pir::Operation* op) {
Expand Down Expand Up @@ -94,19 +94,19 @@ class CustomOpDialect : public pir::Dialect {
std::vector<const char*> op_names_;
};

class CustomPyOpDialect : public pir::Dialect {
class PythonOperatorDialect : public pir::Dialect {
public:
explicit CustomPyOpDialect(pir::IrContext* context);
explicit PythonOperatorDialect(pir::IrContext* context);

constexpr static const char* name() { return "custom_pyop"; }
constexpr static const char* name() { return "py_op"; }

void PrintType(pir::Type type, std::ostream& os) const override;
void PrintAttribute(pir::Attribute type, std::ostream& os) const override;

pir::OpPrintFn PrintOperation(
const pir::Operation& op) const override; // NOLINT

void RegisterCustomPyOp(const paddle::OpMetaInfo& op_meta);
void RegisterPythonOperator(const paddle::OpMetaInfo& op_meta);

bool HasRegistered(const std::string& op_name) {
if (std::find(op_names_.begin(), op_names_.end(), op_name) !=
Expand Down Expand Up @@ -152,5 +152,5 @@ class TEST_API CustomEngineDialect : public pir::Dialect {

IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::OperatorDialect)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::CustomOpDialect)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::CustomPyOpDialect)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::PythonOperatorDialect)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::CustomEngineDialect)
28 changes: 14 additions & 14 deletions paddle/phi/api/ext/op_meta_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -995,11 +995,11 @@ using InferSpmdFunc = phi::distributed::SpmdInfo (*)(
const std::vector<CustomSpmdInferTensorArg>& inputs,
const std::vector<CustomSpmdInferAttrArg>& attrs);

using WrapPythonFunction =
using PythonOperatorFunctionType =
std::function<std::vector<Tensor>(std::vector<Tensor>&)>;
using IrTensor = paddle::dialect::IrTensor;
using WrapInferMetaPythonFunction = std::function<std::vector<IrTensor>(
const std::vector<paddle::dialect::IrTensor>&)>;
using PythonOperatorInferMetaFunctionType =
std::function<std::vector<IrTensor>(const std::vector<IrTensor>&)>;

class PADDLE_API OpMetaInfo {
public:
Expand Down Expand Up @@ -1031,10 +1031,10 @@ class PADDLE_API OpMetaInfo {
// format: PD_INFER_SPMD_RULE(...)
OpMetaInfo& SetInferSpmdFn(InferSpmdFunc&& func);

// CustomPyOp
OpMetaInfo& SetCustomPyOpFunction(WrapPythonFunction&& func);
// PythonOperator
OpMetaInfo& SetCustomPyOpFunction(PythonOperatorFunctionType&& func);
OpMetaInfo& SetCustomPyOpInferMetaFunction(
WrapInferMetaPythonFunction&& func);
PythonOperatorInferMetaFunctionType&& func);

bool IsGradOp() const;

Expand Down Expand Up @@ -1064,8 +1064,8 @@ class PADDLE_API OpMetaInfo {
InferDtypeFunc infer_dtype_fn_{nullptr};
InferSpmdFunc infer_spmd_fn_{nullptr};
// 3. custom pyop function
WrapPythonFunction pyop_func_{nullptr};
WrapInferMetaPythonFunction pyop_func_infer_meta_{nullptr};
PythonOperatorFunctionType pyop_func_{nullptr};
PythonOperatorInferMetaFunctionType pyop_func_infer_meta_{nullptr};
#ifdef PADDLE_WITH_TENSORRT
TrtGetOutputDimsFunc trt_infer_shape_fn_{nullptr};
std::vector<std::string> trt_supports_format_config_;
Expand All @@ -1092,10 +1092,10 @@ class OpMetaInfoHelper {
static const InferSpmdFunc& GetInferSpmdFn(const paddle::OpMetaInfo& info);

// Python Custom Op
static const WrapPythonFunction& GetPyCustomPyOpFunction(
const paddle::OpMetaInfo& info);
static const WrapInferMetaPythonFunction& GetPyCustomPyOpInferMetaFunction(
static const PythonOperatorFunctionType& GetPythonOperatorFunction(
const paddle::OpMetaInfo& info);
static const PythonOperatorInferMetaFunctionType&
GetPythonOperatorInferMetaFunction(const paddle::OpMetaInfo& info);

#ifdef PADDLE_WITH_TENSORRT
static const TrtGetOutputDimsFunc& GetTrtInferShapeFn(
Expand Down Expand Up @@ -1138,9 +1138,9 @@ class PADDLE_API OpMetaInfoBuilder {
OpMetaInfoBuilder& SetInferDtypeFn(InferDtypeFunc func);
OpMetaInfoBuilder& SetInferSpmdFn(InferSpmdFunc func);

OpMetaInfoBuilder& SetPyCustomPyOpFunction(WrapPythonFunction func);
OpMetaInfoBuilder& SetPyCustomPyOpInferMetaFunction(
WrapInferMetaPythonFunction func);
OpMetaInfoBuilder& SetPythonOperatorFunction(PythonOperatorFunctionType func);
OpMetaInfoBuilder& SetPythonOperatorInferMetaFunction(
PythonOperatorInferMetaFunctionType func);

#ifdef PADDLE_WITH_TENSORRT
OpMetaInfoBuilder& SetTrtInferShapeFn(TrtGetOutputDimsFunc func);
Expand Down
29 changes: 16 additions & 13 deletions paddle/phi/api/lib/op_meta_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -450,14 +450,16 @@ OpMetaInfo& OpMetaInfo::SetInferSpmdFn(InferSpmdFunc&& func) {
infer_spmd_fn_ = std::forward<InferSpmdFunc>(func);
return *this;
}
OpMetaInfo& OpMetaInfo::SetCustomPyOpFunction(WrapPythonFunction&& func) {
pyop_func_ = std::forward<WrapPythonFunction>(func);
OpMetaInfo& OpMetaInfo::SetCustomPyOpFunction(
PythonOperatorFunctionType&& func) {
pyop_func_ = std::forward<PythonOperatorFunctionType>(func);
return *this;
}

OpMetaInfo& OpMetaInfo::SetCustomPyOpInferMetaFunction(
WrapInferMetaPythonFunction&& func) {
pyop_func_infer_meta_ = std::forward<WrapInferMetaPythonFunction>(func);
PythonOperatorInferMetaFunctionType&& func) {
pyop_func_infer_meta_ =
std::forward<PythonOperatorInferMetaFunctionType>(func);
return *this;
}

Expand Down Expand Up @@ -530,13 +532,13 @@ const InferSpmdFunc& OpMetaInfoHelper::GetInferSpmdFn(
}

// Python Custom Op
const WrapPythonFunction& OpMetaInfoHelper::GetPyCustomPyOpFunction(
const PythonOperatorFunctionType& OpMetaInfoHelper::GetPythonOperatorFunction(
const paddle::OpMetaInfo& info) {
return info.pyop_func_;
}

const WrapInferMetaPythonFunction&
OpMetaInfoHelper::GetPyCustomPyOpInferMetaFunction(
const PythonOperatorInferMetaFunctionType&
OpMetaInfoHelper::GetPythonOperatorInferMetaFunction(
const paddle::OpMetaInfo& info) {
return info.pyop_func_infer_meta_;
}
Expand Down Expand Up @@ -700,16 +702,17 @@ OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferSpmdFn(InferSpmdFunc func) {
return *this;
}

OpMetaInfoBuilder& OpMetaInfoBuilder::SetPyCustomPyOpFunction(
WrapPythonFunction func) {
info_ptr_->SetCustomPyOpFunction(std::forward<WrapPythonFunction>(func));
OpMetaInfoBuilder& OpMetaInfoBuilder::SetPythonOperatorFunction(
PythonOperatorFunctionType func) {
info_ptr_->SetCustomPyOpFunction(
std::forward<PythonOperatorFunctionType>(func));
return *this;
}

OpMetaInfoBuilder& OpMetaInfoBuilder::SetPyCustomPyOpInferMetaFunction(
WrapInferMetaPythonFunction func) {
OpMetaInfoBuilder& OpMetaInfoBuilder::SetPythonOperatorInferMetaFunction(
PythonOperatorInferMetaFunctionType func) {
info_ptr_->SetCustomPyOpInferMetaFunction(
std::forward<WrapInferMetaPythonFunction>(func));
std::forward<PythonOperatorInferMetaFunctionType>(func));
return *this;
}

Expand Down
Loading