2 changes: 1 addition & 1 deletion paddle/fluid/framework/CMakeLists.txt
Member:

Let's unify the naming:

  • use py_op as the dialect name
  • in file names and other non-abbreviated contexts, use python_operator, with neither word abbreviated

This avoids conflicts with custom op.

Contributor Author:

What would be a good name for the kernel dialect after lowering? It was previously called custom_py_func; should it be changed to python_operator_function?

Member:

What is the corresponding one for custom operators called?

@@ -577,7 +577,7 @@ configure_file(commit.h.in commit.h)

cc_library(
custom_operator
SRCS custom_operator.cc
SRCS custom_operator.cc custom_pyoperator.cc
DEPS tensor
attribute
op_registry
39 changes: 36 additions & 3 deletions paddle/fluid/framework/custom_operator_utils.h
@@ -26,9 +26,10 @@ limitations under the License. */

namespace paddle {
namespace framework {
constexpr char kCustomDialectPrefix[] = "custom_op."; // NOLINT
constexpr char kGradSuffix[] = "_grad"; // NOLINT
constexpr char kDoubleGradSuffix[] = "_grad_grad"; // NOLINT
constexpr char kCustomDialectPrefix[] = "custom_op."; // NOLINT
constexpr char kCustomPyDialectPrefix[] = "custom_pyop."; // NOLINT
constexpr char kGradSuffix[] = "_grad"; // NOLINT
constexpr char kDoubleGradSuffix[] = "_grad_grad"; // NOLINT

namespace detail {

@@ -157,6 +158,38 @@ inline static const OpMetaInfo& GetOpInfoByPirName(
}
}

inline static const OpMetaInfo& GetCustomPyOpInfoByPirName(
const std::string& pir_op_name) {
auto custom_name = pir_op_name.substr(strlen(kCustomPyDialectPrefix));
int pos = custom_name.length();

if (custom_name[pos - 1] == '_') {
custom_name = custom_name.substr(0, pos - 1);
}

pos = custom_name.length();
if (custom_name.find(kDoubleGradSuffix) != custom_name.npos) {
pos = custom_name.find(kDoubleGradSuffix);
} else if (custom_name.find(kGradSuffix) != custom_name.npos) {
pos = custom_name.find(kGradSuffix);
}
auto custom_name_prefix = custom_name.substr(0, pos);
auto map_iter =
paddle::OpMetaInfoMap::Instance().GetMap().find(custom_name_prefix);
if (map_iter == paddle::OpMetaInfoMap::Instance().GetMap().end()) {
PADDLE_THROW("The info of custom python op: " + custom_name +
" does not exist!");
}
const auto& vec_op_meta = map_iter->second;
if (custom_name.find(kDoubleGradSuffix) != custom_name.npos) {
return vec_op_meta[2];
} else if (custom_name.find(kGradSuffix) != custom_name.npos) {
return vec_op_meta[1];
} else {
return vec_op_meta[0];
}
}

inline static bool HasGradOp(const std::string& fwd_pir_op_name) {
auto custom_name = fwd_pir_op_name.substr(strlen(kCustomDialectPrefix));
int pos = custom_name.length();
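For reference, a minimal standalone sketch (not part of this PR) of the name-parsing rules that GetCustomPyOpInfoByPirName applies: strip the dialect prefix, drop a trailing '_' (the inplace variant), then select the meta-info slot by grad suffix. The op names below are hypothetical.

#include <cstring>
#include <iostream>
#include <string>

// 0 = forward, 1 = grad, 2 = double grad; mirrors the vec_op_meta indexing.
int MetaInfoSlot(const std::string& pir_op_name) {
  constexpr char kPrefix[] = "custom_pyop.";
  std::string name = pir_op_name.substr(strlen(kPrefix));
  if (!name.empty() && name.back() == '_') name.pop_back();
  if (name.find("_grad_grad") != std::string::npos) return 2;
  if (name.find("_grad") != std::string::npos) return 1;
  return 0;
}

int main() {
  std::cout << MetaInfoSlot("custom_pyop.my_relu") << "\n";            // 0
  std::cout << MetaInfoSlot("custom_pyop.my_relu_") << "\n";           // 0 (inplace)
  std::cout << MetaInfoSlot("custom_pyop.my_relu_grad") << "\n";       // 1
  std::cout << MetaInfoSlot("custom_pyop.my_relu_grad_grad") << "\n";  // 2
}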
66 changes: 66 additions & 0 deletions paddle/fluid/framework/custom_pyoperator.cc
@@ -0,0 +1,66 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <utility> // for std::move

#include "paddle/fluid/framework/custom_operator_utils.h"
#include "paddle/fluid/framework/custom_pyoperator.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/phi/api/ext/op_meta_info.h"

namespace paddle::framework {

void RegisterCustomPyOp(
const std::string& op_name,
std::vector<std::string>&& op_inputs,
std::vector<std::string>&& op_outputs,
std::vector<std::string>&& op_attrs,
std::unordered_map<std::string, std::string>&& op_inplace_map,
WrapPythonFunction&& pyop_func,
WrapInferMetaPythonFunction&& pyop_func_infer_meta) {
::paddle::OpMetaInfoBuilder op_meta_info_builder =
::paddle::OpMetaInfoBuilder(std::string(op_name), 0);
op_meta_info_builder.Inputs(std::move(op_inputs))
.Outputs(std::move(op_outputs))
.Attrs(std::move(op_attrs))
.SetInplaceMap(std::move(op_inplace_map))
.SetPyCustomPyOpFunction(pyop_func)
.SetPyCustomPyOpInferMetaFunction(pyop_func_infer_meta);

const std::vector<paddle::OpMetaInfo>& op_meta_info_vector =
OpMetaInfoMap::Instance()[op_name];

PADDLE_ENFORCE_EQ(op_meta_info_vector.size(),
1,
common::errors::OutOfRange(
"Current op_name(%s) must not be registered more "
"than once, because gradient ops are not supported.",
op_name));

const auto& op_meta_info = op_meta_info_vector.back();

auto& inplace_map = OpMetaInfoHelper::GetInplaceMap(op_meta_info);
const auto postfix = inplace_map.empty() ? "" : "_";

::pir::IrContext* ctx = ::pir::IrContext::Instance();
auto* custom_pyop_dialect =
ctx->GetOrRegisterDialect<paddle::dialect::CustomPyOpDialect>();

if (custom_pyop_dialect->HasRegistered(
paddle::framework::kCustomPyDialectPrefix + op_name + postfix)) {
return;
}
custom_pyop_dialect->RegisterCustomPyOp(op_meta_info);
}

} // namespace paddle::framework
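A quick sketch (an assumption, not from the PR) of the PIR op name this registration path produces: the dialect prefix plus the op name, with a trailing '_' appended when the op declares an inplace map, matching the postfix logic above. The op names are hypothetical.

#include <iostream>
#include <string>
#include <unordered_map>

// Mirrors the name construction in RegisterCustomPyOp:
// kCustomPyDialectPrefix + op_name (+ '_' for inplace variants).
std::string RegisteredPirName(
    const std::string& op_name,
    const std::unordered_map<std::string, std::string>& inplace_map) {
  return "custom_pyop." + op_name + (inplace_map.empty() ? "" : "_");
}

int main() {
  std::cout << RegisteredPirName("my_relu", {}) << "\n";              // custom_pyop.my_relu
  std::cout << RegisteredPirName("my_relu", {{"X", "Out"}}) << "\n";  // custom_pyop.my_relu_
}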
37 changes: 37 additions & 0 deletions paddle/fluid/framework/custom_pyoperator.h
@@ -0,0 +1,37 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <functional>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/pir/dialect/operator/ir/ir_tensor.h"
#include "paddle/phi/api/ext/op_meta_info.h"

namespace paddle {
namespace framework {

void RegisterCustomPyOp(
const std::string& op_name,
std::vector<std::string>&& op_inputs,
std::vector<std::string>&& op_outputs,
std::vector<std::string>&& op_attrs,
std::unordered_map<std::string, std::string>&& op_inplace_map,
WrapPythonFunction&& func,
WrapInferMetaPythonFunction&& infer_meta);

} // namespace framework
} // namespace paddle
155 changes: 155 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/op_dialect.cc
@@ -622,6 +622,86 @@ struct CustomOpInfoInterfaceModel : public OpYamlInfoInterface::Concept {
CustomOpInfoInterfaceModel() : OpYamlInfoInterface::Concept(GetPirOpInfo) {}
};

struct CustomPyOpInfoInterfaceModel : public OpYamlInfoInterface::Concept {
static OpInfoTuple GetPirOpInfo(const std::string& pir_op_name) {
const auto& op_meta =
paddle::framework::detail::GetCustomPyOpInfoByPirName(pir_op_name);

// TODO(DrRyanHuang): we may support custom_pyop's grad op in the future
// const auto* grad_op_meta_ptr =
// paddle::framework::detail::GetGradOpInfoByFwdPirName(pir_op_name);
std::vector<paddle::dialect::OpInputInfo> inputs_info;
std::vector<paddle::dialect::OpAttributeInfo> attributes_info;
std::vector<paddle::dialect::OpOutputInfo> outputs_info;
std::vector<std::string> param_names;

// translate input info
auto& op_input_names = OpMetaInfoHelper::GetInputs(op_meta);
for (const auto& input_name : op_input_names) {
param_names.push_back(input_name);
// Now, we only support dense tensor as input.
inputs_info.push_back(paddle::dialect::OpInputInfo{
input_name,
/*input_type=*/"paddle::dialect::DenseTensorType",
/*optional=*/false,
/*no_need_buffer=*/false,
/*is_mutable_attribute=*/false,
/*with_grad_semantic=*/false});
}

// translate attr info
auto& op_attrs = OpMetaInfoHelper::GetAttrs(op_meta);
for (const auto& op_attr : op_attrs) {
auto attr_name_and_type = paddle::ParseAttrStr(op_attr);
// CustomPyOp only has int64_t attrs
const std::string& attr_name = attr_name_and_type[0];
const std::string& attr_type_str = attr_name_and_type[1];
PADDLE_ENFORCE_EQ(attr_type_str,
"int64_t",
common::errors::InvalidArgument(
"CustomPyOp only has two int64_t attributes, which "
"are infer_meta_fn_ptr & fn_ptr."));
param_names.push_back(attr_name);
const std::string& attr_pir_type =
CppTypeToAttrTypeMap().at(attr_type_str);
attributes_info.emplace_back(attr_name, attr_pir_type, "");
}

// translate output info
auto& op_output_names = OpMetaInfoHelper::GetOutputs(op_meta);
for (const auto& output_name : op_output_names) {
// Now, we only support dense tensor as output.
outputs_info.push_back(paddle::dialect::OpOutputInfo{
output_name,
/*type_name=*/"paddle::dialect::DenseTensorType",
/*is_optional=*/false,
/*intermediate=*/false});
}

auto& inplace_maps = OpMetaInfoHelper::GetInplaceReverseMap(op_meta);
if (!inplace_maps.empty()) {
VLOG(3) << "Register Custom Python Operator: op inplace_map: "
<< string::join_strings(inplace_maps, ',', [](auto& pair) {
return pair.first + ": " + pair.second;
});
}
std::vector<std::pair<std::string, std::string>> vec_inplace;
for (const auto& inplace_map : inplace_maps) {
vec_inplace.emplace_back(inplace_map);
}

// we only need kernel params name in run_time_info
paddle::dialect::OpRunTimeInfo run_time_info =
paddle::dialect::OpRunTimeInfo(
"", {}, "", param_names, {}, {}, vec_inplace, {});

return std::make_tuple(
inputs_info, attributes_info, outputs_info, run_time_info, "");
}

CustomPyOpInfoInterfaceModel() : OpYamlInfoInterface::Concept(GetPirOpInfo) {}
};

struct CustomOpVjpInterfaceModel : public VjpInterface::Concept {
static std::vector<std::vector<pir::Value>> CustomOpVjp(
pir::Operation* op,
@@ -1141,6 +1221,80 @@ void CustomOpDialect::RegisterCustomOp(const paddle::OpMetaInfo& op_meta) {
verify_func);
}

CustomPyOpDialect::CustomPyOpDialect(pir::IrContext* context)
: pir::Dialect(name(), context, pir::TypeId::get<CustomPyOpDialect>()) {}

void CustomPyOpDialect::PrintType(pir::Type type, std::ostream& os) const {
PrintTypeImpl(type, os);
}

void CustomPyOpDialect::PrintAttribute(pir::Attribute attr,
std::ostream& os) const {
PrintAttributeImpl(attr, os);
}

pir::OpPrintFn CustomPyOpDialect::PrintOperation(
const pir::Operation& op) const {
return nullptr;
}

void CustomPyOpDialect::RegisterCustomPyOp(const paddle::OpMetaInfo& op_meta) {
pir::TypeId id = IdManager::Instance().CreateId();
std::string op_name = paddle::framework::kCustomPyDialectPrefix +
OpMetaInfoHelper::GetOpName(op_meta);
std::vector<pir::TypeId> traits;

auto& inplace_map = OpMetaInfoHelper::GetInplaceMap(op_meta);
if (!inplace_map.empty()) {
op_name += "_";
traits.push_back(pir::TypeId::get<paddle::dialect::InplaceTrait>());
}

char* op_name_c = new char[op_name.size() + 1];
snprintf(op_name_c, op_name.size() + 1, "%s", op_name.c_str());
op_names_.push_back(op_name_c);

auto& op_attrs = OpMetaInfoHelper::GetAttrs(op_meta);
std::vector<std::string> attr_names;
for (const auto& op_attr : op_attrs) {
auto attr_name_and_type = paddle::ParseAttrStr(op_attr);
auto attr_name = attr_name_and_type[0];
attr_names.push_back(attr_name);
}
const char** attr_name =
AttributeManager::Instance().ToCharPointers(attr_names);
uint32_t attr_num = attr_names.size();

std::cout << "attr_num: " << attr_num << std::endl;

std::set<pir::InterfaceValue> interface_values;
pir::InterfaceValue op_info_interface =
pir::InterfaceValue::Get<OpYamlInfoInterface,
CustomPyOpInfoInterfaceModel>();
interface_values.insert(std::move(op_info_interface));

// TODO(DrRyanHuang): Currently, we do not support vjp for customPyOp.
// if (paddle::framework::detail::HasGradOp(op_name)) {
// pir::InterfaceValue vjp_interface =
// pir::InterfaceValue::Get<VjpInterface,
// CustomPyOpVjpInterfaceModel>();
// interface_values.insert(std::move(vjp_interface));
// }

// Currently we set an empty verify function and will reset it if it is used
// in the future.
pir::VerifyPtr verify_func = [](pir::Operation* op) {};
ir_context()->RegisterOpInfo(this,
id,
op_names_.back(),
std::move(interface_values),
traits,
attr_num,
attr_name,
verify_func,
verify_func);
}

// customEngineDialect

CustomEngineDialect::CustomEngineDialect(pir::IrContext* context)
@@ -1169,4 +1323,5 @@ pir::OpPrintFn CustomEngineDialect::PrintOperation(

IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OperatorDialect)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::CustomOpDialect)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::CustomPyOpDialect)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::CustomEngineDialect)
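The two int64_t attributes enforced in GetPirOpInfo above (fn_ptr and infer_meta_fn_ptr) suggest the Python-side callables travel through the IR as integer-encoded pointers. The consuming side is not shown in this diff, so the round-trip below is only an illustrative assumption.

#include <cstdint>
#include <iostream>

// Stand-in for a wrapped Python callable.
void MyKernel() { std::cout << "kernel body\n"; }

int main() {
  // Encode the function pointer as an int64_t attribute value...
  auto fn_ptr = reinterpret_cast<int64_t>(&MyKernel);
  // ...and decode it again where the op is executed.
  auto* fn = reinterpret_cast<void (*)()>(fn_ptr);
  fn();  // prints "kernel body"
}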
32 changes: 32 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/op_dialect.h
@@ -48,6 +48,11 @@ inline bool IsCustomOp(pir::Operation* op) {
return op_name.find("custom_op") != op_name.npos;
}

inline bool IsCustomPyOp(pir::Operation* op) {
const std::string& op_name = op->name();
return op_name.find("custom_pyop") != op_name.npos;
}

inline bool IsCustomEngineOp(pir::Operation* op) {
std::string op_name = op->name();
return op_name.find("custom_engine") != op_name.npos;
@@ -89,6 +94,32 @@ class CustomOpDialect : public pir::Dialect {
std::vector<const char*> op_names_;
};

class CustomPyOpDialect : public pir::Dialect {
public:
explicit CustomPyOpDialect(pir::IrContext* context);

constexpr static const char* name() { return "custom_pyop"; }

void PrintType(pir::Type type, std::ostream& os) const override;
void PrintAttribute(pir::Attribute type, std::ostream& os) const override;

pir::OpPrintFn PrintOperation(
const pir::Operation& op) const override; // NOLINT

void RegisterCustomPyOp(const paddle::OpMetaInfo& op_meta);

bool HasRegistered(const std::string& op_name) {
if (std::find(op_names_.begin(), op_names_.end(), op_name) !=
op_names_.end()) {
return true;
}
return false;
}

private:
std::vector<const char*> op_names_;
};

class TEST_API CustomEngineDialect : public pir::Dialect {
public:
explicit CustomEngineDialect(pir::IrContext* context);
@@ -121,4 +152,5 @@ class TEST_API CustomEngineDialect : public pir::Dialect {

IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::OperatorDialect)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::CustomOpDialect)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::CustomPyOpDialect)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::CustomEngineDialect)