Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
rename custom_pyop -> python_operator
  • Loading branch information
DrRyanHuang committed Dec 15, 2025
commit 0a0c7d55ab53f900aa5f0f50a38e37b8621f6b51
2 changes: 1 addition & 1 deletion paddle/fluid/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ configure_file(commit.h.in commit.h)

cc_library(
custom_operator
SRCS custom_operator.cc custom_pyoperator.cc
SRCS custom_operator.cc python_operator.cc
DEPS tensor
attribute
op_registry
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/framework/custom_operator_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ limitations under the License. */
namespace paddle {
namespace framework {
constexpr char kCustomDialectPrefix[] = "custom_op."; // NOLINT
constexpr char kCustomPyDialectPrefix[] = "custom_pyop."; // NOLINT
constexpr char kPythonOperatorDialectPrefix[] = "py_op."; // NOLINT
constexpr char kGradSuffix[] = "_grad"; // NOLINT
constexpr char kDoubleGradSuffix[] = "_grad_grad"; // NOLINT

Expand Down Expand Up @@ -158,9 +158,9 @@ inline static const OpMetaInfo& GetOpInfoByPirName(
}
}

inline static const OpMetaInfo& GetCustomPyOpInfoByPirName(
inline static const OpMetaInfo& GetPythonOperatorInfoByPirName(
const std::string& pir_op_name) {
auto custom_name = pir_op_name.substr(strlen(kCustomPyDialectPrefix));
auto custom_name = pir_op_name.substr(strlen(kPythonOperatorDialectPrefix));
int pos = custom_name.length();

if (custom_name[pos - 1] == '_') {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,28 @@
#include <utility> // for std::move

#include "paddle/fluid/framework/custom_operator_utils.h"
#include "paddle/fluid/framework/custom_pyoperator.h"
#include "paddle/fluid/framework/python_operator.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/phi/api/ext/op_meta_info.h"

namespace paddle::framework {

void RegisterCustomPyOp(
void RegisterPythonOperator(
const std::string& op_name,
std::vector<std::string>&& op_inputs,
std::vector<std::string>&& op_outputs,
std::vector<std::string>&& op_attrs,
std::unordered_map<std::string, std::string>&& op_inplace_map,
WrapPythonFunction&& pyop_func,
WrapInferMetaPythonFunction&& pyop_func_infer_meta) {
PythonOperatorFunctionType&& pyop_func,
PythonOperatorInferMetaFunctionType&& pyop_func_infer_meta) {
::paddle::OpMetaInfoBuilder op_meta_info_builder =
::paddle::OpMetaInfoBuilder(std::string(op_name), 0);
op_meta_info_builder.Inputs(std::move(op_inputs))
.Outputs(std::move(op_outputs))
.Attrs(std::move(op_attrs))
.SetInplaceMap(std::move(op_inplace_map))
.SetPyCustomPyOpFunction(pyop_func)
.SetPyCustomPyOpInferMetaFunction(pyop_func_infer_meta);
.SetPythonOperatorFunction(pyop_func)
.SetPythonOperatorInferMetaFunction(pyop_func_infer_meta);

const std::vector<paddle::OpMetaInfo>& op_meta_info_vector =
OpMetaInfoMap::Instance()[op_name];
Expand All @@ -50,17 +50,17 @@ void RegisterCustomPyOp(
const auto& op_meta_info = op_meta_info_vector.back();

auto& inplace_map = OpMetaInfoHelper::GetInplaceMap(op_meta_info);
const auto postfix = inplace_map.empty() ? "" : "_";
const auto suffix = inplace_map.empty() ? "" : "_";

::pir::IrContext* ctx = ::pir::IrContext::Instance();
auto* custom_pyop_dialect =
ctx->GetOrRegisterDialect<paddle::dialect::CustomPyOpDialect>();
auto* python_operator_dialect =
ctx->GetOrRegisterDialect<paddle::dialect::PythonOperatorDialect>();

if (custom_pyop_dialect->HasRegistered(
paddle::framework::kCustomPyDialectPrefix + op_name + postfix)) {
if (python_operator_dialect->HasRegistered(
paddle::framework::kPythonOperatorDialectPrefix + op_name + suffix)) {
return;
}
custom_pyop_dialect->RegisterCustomPyOp(op_meta_info);
python_operator_dialect->RegisterPythonOperator(op_meta_info);
}

} // namespace paddle::framework
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@
namespace paddle {
namespace framework {

void RegisterCustomPyOp(
void RegisterPythonOperator(
const std::string& op_name,
std::vector<std::string>&& op_inputs,
std::vector<std::string>&& op_outputs,
std::vector<std::string>&& op_attrs,
std::unordered_map<std::string, std::string>&& op_inplace_map,
WrapPythonFunction&& func,
WrapInferMetaPythonFunction&& infer_meta);
PythonOperatorFunctionType&& func,
PythonOperatorInferMetaFunctionType&& infer_meta);

} // namespace framework
} // namespace paddle
37 changes: 20 additions & 17 deletions paddle/fluid/pir/dialect/operator/ir/op_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -625,9 +625,9 @@ struct CustomOpInfoInterfaceModel : public OpYamlInfoInterface::Concept {
struct CustomPyOpInfoInterfaceModel : public OpYamlInfoInterface::Concept {
static OpInfoTuple GetPirOpInfo(const std::string& pir_op_name) {
const auto& op_meta =
paddle::framework::detail::GetCustomPyOpInfoByPirName(pir_op_name);
paddle::framework::detail::GetPythonOperatorInfoByPirName(pir_op_name);

// TODO(DrRyanHuang): we may support custom_pyop's grad op in the future
// TODO(DrRyanHuang): we may support py_op's grad op in the future
// const auto* grad_op_meta_ptr =
// paddle::framework::detail::GetGradOpInfoByFwdPirName(pir_op_name);
std::vector<paddle::dialect::OpInputInfo> inputs_info;
Expand All @@ -653,14 +653,15 @@ struct CustomPyOpInfoInterfaceModel : public OpYamlInfoInterface::Concept {
auto& op_attrs = OpMetaInfoHelper::GetAttrs(op_meta);
for (const auto& op_attr : op_attrs) {
auto attr_name_and_type = paddle::ParseAttrStr(op_attr);
// CustomPyOp only has int64_t attr
// PythonOperator only has int64_t attr
const std::string& attr_name = attr_name_and_type[0];
const std::string& attr_type_str = attr_name_and_type[1];
PADDLE_ENFORCE_EQ(attr_type_str,
"int64_t",
common::errors::InvalidArgument(
"CustomPyOp only has two int64_t attributes, which "
"are infer_meta_fn_ptr & fn_ptr."));
PADDLE_ENFORCE_EQ(
attr_type_str,
"int64_t",
common::errors::InvalidArgument(
"PythonOperator only has two int64_t attributes, which "
"are infer_meta_fn_ptr & fn_ptr."));
param_names.push_back(attr_name);
const std::string& attr_pir_type =
CppTypeToAttrTypeMap().at(attr_type_str);
Expand Down Expand Up @@ -1221,26 +1222,28 @@ void CustomOpDialect::RegisterCustomOp(const paddle::OpMetaInfo& op_meta) {
verify_func);
}

CustomPyOpDialect::CustomPyOpDialect(pir::IrContext* context)
: pir::Dialect(name(), context, pir::TypeId::get<CustomPyOpDialect>()) {}
PythonOperatorDialect::PythonOperatorDialect(pir::IrContext* context)
: pir::Dialect(name(), context, pir::TypeId::get<PythonOperatorDialect>()) {
}

void CustomPyOpDialect::PrintType(pir::Type type, std::ostream& os) const {
void PythonOperatorDialect::PrintType(pir::Type type, std::ostream& os) const {
PrintTypeImpl(type, os);
}

void CustomPyOpDialect::PrintAttribute(pir::Attribute attr,
std::ostream& os) const {
void PythonOperatorDialect::PrintAttribute(pir::Attribute attr,
std::ostream& os) const {
PrintAttributeImpl(attr, os);
}

pir::OpPrintFn CustomPyOpDialect::PrintOperation(
pir::OpPrintFn PythonOperatorDialect::PrintOperation(
const pir::Operation& op) const {
return nullptr;
}

void CustomPyOpDialect::RegisterCustomPyOp(const paddle::OpMetaInfo& op_meta) {
void PythonOperatorDialect::RegisterPythonOperator(
const paddle::OpMetaInfo& op_meta) {
pir::TypeId id = IdManager::Instance().CreateId();
std::string op_name = paddle::framework::kCustomPyDialectPrefix +
std::string op_name = paddle::framework::kPythonOperatorDialectPrefix +
OpMetaInfoHelper::GetOpName(op_meta);
std::vector<pir::TypeId> traits;

Expand Down Expand Up @@ -1323,5 +1326,5 @@ pir::OpPrintFn CustomEngineDialect::PrintOperation(

IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OperatorDialect)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::CustomOpDialect)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::CustomPyOpDialect)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::PythonOperatorDialect)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::CustomEngineDialect)
12 changes: 6 additions & 6 deletions paddle/fluid/pir/dialect/operator/ir/op_dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ inline bool IsCustomOp(pir::Operation* op) {

inline bool IsCustomPyOp(pir::Operation* op) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的Custom是否要保留?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

part2 #76938 会一起改一下

const std::string& op_name = op->name();
return op_name.find("custom_pyop") != op_name.npos;
return op_name.find("py_op") != op_name.npos;
}

inline bool IsCustomEngineOp(pir::Operation* op) {
Expand Down Expand Up @@ -94,19 +94,19 @@ class CustomOpDialect : public pir::Dialect {
std::vector<const char*> op_names_;
};

class CustomPyOpDialect : public pir::Dialect {
class PythonOperatorDialect : public pir::Dialect {
public:
explicit CustomPyOpDialect(pir::IrContext* context);
explicit PythonOperatorDialect(pir::IrContext* context);

constexpr static const char* name() { return "custom_pyop"; }
constexpr static const char* name() { return "py_op"; }

void PrintType(pir::Type type, std::ostream& os) const override;
void PrintAttribute(pir::Attribute type, std::ostream& os) const override;

pir::OpPrintFn PrintOperation(
const pir::Operation& op) const override; // NOLINT

void RegisterCustomPyOp(const paddle::OpMetaInfo& op_meta);
void RegisterPythonOperator(const paddle::OpMetaInfo& op_meta);

bool HasRegistered(const std::string& op_name) {
if (std::find(op_names_.begin(), op_names_.end(), op_name) !=
Expand Down Expand Up @@ -152,5 +152,5 @@ class TEST_API CustomEngineDialect : public pir::Dialect {

IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::OperatorDialect)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::CustomOpDialect)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::CustomPyOpDialect)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::PythonOperatorDialect)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::CustomEngineDialect)
28 changes: 14 additions & 14 deletions paddle/phi/api/ext/op_meta_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -995,11 +995,11 @@ using InferSpmdFunc = phi::distributed::SpmdInfo (*)(
const std::vector<CustomSpmdInferTensorArg>& inputs,
const std::vector<CustomSpmdInferAttrArg>& attrs);

using WrapPythonFunction =
using PythonOperatorFunctionType =
std::function<std::vector<Tensor>(std::vector<Tensor>&)>;
using IrTensor = paddle::dialect::IrTensor;
using WrapInferMetaPythonFunction = std::function<std::vector<IrTensor>(
const std::vector<paddle::dialect::IrTensor>&)>;
using PythonOperatorInferMetaFunctionType =
std::function<std::vector<IrTensor>(const std::vector<IrTensor>&)>;

class PADDLE_API OpMetaInfo {
public:
Expand Down Expand Up @@ -1031,10 +1031,10 @@ class PADDLE_API OpMetaInfo {
// format: PD_INFER_SPMD_RULE(...)
OpMetaInfo& SetInferSpmdFn(InferSpmdFunc&& func);

// CustomPyOp
OpMetaInfo& SetCustomPyOpFunction(WrapPythonFunction&& func);
// PythonOperator
OpMetaInfo& SetCustomPyOpFunction(PythonOperatorFunctionType&& func);
OpMetaInfo& SetCustomPyOpInferMetaFunction(
WrapInferMetaPythonFunction&& func);
PythonOperatorInferMetaFunctionType&& func);

bool IsGradOp() const;

Expand Down Expand Up @@ -1064,8 +1064,8 @@ class PADDLE_API OpMetaInfo {
InferDtypeFunc infer_dtype_fn_{nullptr};
InferSpmdFunc infer_spmd_fn_{nullptr};
// 3. custom pyop function
WrapPythonFunction pyop_func_{nullptr};
WrapInferMetaPythonFunction pyop_func_infer_meta_{nullptr};
PythonOperatorFunctionType pyop_func_{nullptr};
PythonOperatorInferMetaFunctionType pyop_func_infer_meta_{nullptr};
#ifdef PADDLE_WITH_TENSORRT
TrtGetOutputDimsFunc trt_infer_shape_fn_{nullptr};
std::vector<std::string> trt_supports_format_config_;
Expand All @@ -1092,10 +1092,10 @@ class OpMetaInfoHelper {
static const InferSpmdFunc& GetInferSpmdFn(const paddle::OpMetaInfo& info);

// Python Custom Op
static const WrapPythonFunction& GetPyCustomPyOpFunction(
const paddle::OpMetaInfo& info);
static const WrapInferMetaPythonFunction& GetPyCustomPyOpInferMetaFunction(
static const PythonOperatorFunctionType& GetPythonOperatorFunction(
const paddle::OpMetaInfo& info);
static const PythonOperatorInferMetaFunctionType&
GetPythonOperatorInferMetaFunction(const paddle::OpMetaInfo& info);

#ifdef PADDLE_WITH_TENSORRT
static const TrtGetOutputDimsFunc& GetTrtInferShapeFn(
Expand Down Expand Up @@ -1138,9 +1138,9 @@ class PADDLE_API OpMetaInfoBuilder {
OpMetaInfoBuilder& SetInferDtypeFn(InferDtypeFunc func);
OpMetaInfoBuilder& SetInferSpmdFn(InferSpmdFunc func);

OpMetaInfoBuilder& SetPyCustomPyOpFunction(WrapPythonFunction func);
OpMetaInfoBuilder& SetPyCustomPyOpInferMetaFunction(
WrapInferMetaPythonFunction func);
OpMetaInfoBuilder& SetPythonOperatorFunction(PythonOperatorFunctionType func);
OpMetaInfoBuilder& SetPythonOperatorInferMetaFunction(
PythonOperatorInferMetaFunctionType func);

#ifdef PADDLE_WITH_TENSORRT
OpMetaInfoBuilder& SetTrtInferShapeFn(TrtGetOutputDimsFunc func);
Expand Down
29 changes: 16 additions & 13 deletions paddle/phi/api/lib/op_meta_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -450,14 +450,16 @@ OpMetaInfo& OpMetaInfo::SetInferSpmdFn(InferSpmdFunc&& func) {
infer_spmd_fn_ = std::forward<InferSpmdFunc>(func);
return *this;
}
OpMetaInfo& OpMetaInfo::SetCustomPyOpFunction(WrapPythonFunction&& func) {
pyop_func_ = std::forward<WrapPythonFunction>(func);
OpMetaInfo& OpMetaInfo::SetCustomPyOpFunction(
PythonOperatorFunctionType&& func) {
pyop_func_ = std::forward<PythonOperatorFunctionType>(func);
return *this;
}

OpMetaInfo& OpMetaInfo::SetCustomPyOpInferMetaFunction(
WrapInferMetaPythonFunction&& func) {
pyop_func_infer_meta_ = std::forward<WrapInferMetaPythonFunction>(func);
PythonOperatorInferMetaFunctionType&& func) {
pyop_func_infer_meta_ =
std::forward<PythonOperatorInferMetaFunctionType>(func);
return *this;
}

Expand Down Expand Up @@ -530,13 +532,13 @@ const InferSpmdFunc& OpMetaInfoHelper::GetInferSpmdFn(
}

// Python Custom Op
const WrapPythonFunction& OpMetaInfoHelper::GetPyCustomPyOpFunction(
const PythonOperatorFunctionType& OpMetaInfoHelper::GetPythonOperatorFunction(
const paddle::OpMetaInfo& info) {
return info.pyop_func_;
}

const WrapInferMetaPythonFunction&
OpMetaInfoHelper::GetPyCustomPyOpInferMetaFunction(
const PythonOperatorInferMetaFunctionType&
OpMetaInfoHelper::GetPythonOperatorInferMetaFunction(
const paddle::OpMetaInfo& info) {
return info.pyop_func_infer_meta_;
}
Expand Down Expand Up @@ -700,16 +702,17 @@ OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferSpmdFn(InferSpmdFunc func) {
return *this;
}

OpMetaInfoBuilder& OpMetaInfoBuilder::SetPyCustomPyOpFunction(
WrapPythonFunction func) {
info_ptr_->SetCustomPyOpFunction(std::forward<WrapPythonFunction>(func));
OpMetaInfoBuilder& OpMetaInfoBuilder::SetPythonOperatorFunction(
PythonOperatorFunctionType func) {
info_ptr_->SetCustomPyOpFunction(
std::forward<PythonOperatorFunctionType>(func));
return *this;
}

OpMetaInfoBuilder& OpMetaInfoBuilder::SetPyCustomPyOpInferMetaFunction(
WrapInferMetaPythonFunction func) {
OpMetaInfoBuilder& OpMetaInfoBuilder::SetPythonOperatorInferMetaFunction(
PythonOperatorInferMetaFunctionType func) {
info_ptr_->SetCustomPyOpInferMetaFunction(
std::forward<WrapInferMetaPythonFunction>(func));
std::forward<PythonOperatorInferMetaFunctionType>(func));
return *this;
}

Expand Down
Loading