Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cmake/inference_lib.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,10 @@ copy(
inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/enforce.h
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/)
copy(
inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/ddim.h
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/)
copy(
inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/utils/string/*.h
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ configure_file(commit.h.in commit.h)

cc_library(
custom_operator
SRCS custom_operator.cc
SRCS custom_operator.cc python_operator.cc
DEPS tensor
attribute
op_registry
Expand Down
39 changes: 36 additions & 3 deletions paddle/fluid/framework/custom_operator_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@ limitations under the License. */

namespace paddle {
namespace framework {
constexpr char kCustomDialectPrefix[] = "custom_op."; // NOLINT
constexpr char kGradSuffix[] = "_grad"; // NOLINT
constexpr char kDoubleGradSuffix[] = "_grad_grad"; // NOLINT
constexpr char kCustomDialectPrefix[] = "custom_op."; // NOLINT
constexpr char kPythonOperatorDialectPrefix[] = "py_op."; // NOLINT
constexpr char kGradSuffix[] = "_grad"; // NOLINT
constexpr char kDoubleGradSuffix[] = "_grad_grad"; // NOLINT

namespace detail {

Expand Down Expand Up @@ -157,6 +158,38 @@ inline static const OpMetaInfo& GetOpInfoByPirName(
}
}

inline static const OpMetaInfo& GetPythonOperatorInfoByPirName(
const std::string& pir_op_name) {
auto custom_name = pir_op_name.substr(strlen(kPythonOperatorDialectPrefix));
int pos = custom_name.length();

if (custom_name[pos - 1] == '_') {
custom_name = custom_name.substr(0, pos - 1);
}

pos = custom_name.length();
if (custom_name.find(kDoubleGradSuffix) != custom_name.npos) {
pos = custom_name.find(kDoubleGradSuffix);
} else if (custom_name.find(kGradSuffix) != custom_name.npos) {
pos = custom_name.find(kGradSuffix);
}
auto custom_name_prefix = custom_name.substr(0, pos);
auto map_iter =
paddle::OpMetaInfoMap::Instance().GetMap().find(custom_name_prefix);
if (map_iter == paddle::OpMetaInfoMap::Instance().GetMap().end()) {
PADDLE_THROW("The info of custom python op : " + custom_name +
" is not exists!");
}
const auto& vec_op_meta = map_iter->second;
if (custom_name.find(kDoubleGradSuffix) != custom_name.npos) {
return vec_op_meta[2];
} else if (custom_name.find(kGradSuffix) != custom_name.npos) {
return vec_op_meta[1];
} else {
return vec_op_meta[0];
}
}

inline static bool HasGradOp(const std::string& fwd_pir_op_name) {
auto custom_name = fwd_pir_op_name.substr(strlen(kCustomDialectPrefix));
int pos = custom_name.length();
Expand Down
67 changes: 67 additions & 0 deletions paddle/fluid/framework/python_operator.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <utility> // for std::move

#include "paddle/fluid/framework/custom_operator_utils.h"
#include "paddle/fluid/framework/python_operator.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/phi/api/ext/op_meta_info.h"

namespace paddle::framework {

void RegisterPythonOperator(
const std::string& op_name,
std::vector<std::string>&& op_inputs,
std::vector<std::string>&& op_outputs,
std::vector<std::string>&& op_attrs,
std::unordered_map<std::string, std::string>&& op_inplace_map,
PythonOperatorFunctionType&& pyop_func,
PythonOperatorInferMetaFunctionType&& pyop_func_infer_meta) {
::paddle::OpMetaInfoBuilder op_meta_info_builder =
::paddle::OpMetaInfoBuilder(std::string(op_name), 0);
op_meta_info_builder.Inputs(std::move(op_inputs))
.Outputs(std::move(op_outputs))
.Attrs(std::move(op_attrs))
.SetInplaceMap(std::move(op_inplace_map))
.SetPythonOperatorFunction(pyop_func)
.SetPythonOperatorInferMetaFunction(pyop_func_infer_meta);

const std::vector<paddle::OpMetaInfo>& op_meta_info_vector =
OpMetaInfoMap::Instance()[op_name];

PADDLE_ENFORCE_EQ(op_meta_info_vector.size(),
1,
common::errors::OutOfRange(
"Current op_name(%s) must not be registered more "
"than once, because it does not support gradient op.",
op_name));

const auto& op_meta_info = op_meta_info_vector.back();

auto& inplace_map = OpMetaInfoHelper::GetInplaceMap(op_meta_info);
const auto suffix = inplace_map.empty() ? "" : "_";

::pir::IrContext* ctx = ::pir::IrContext::Instance();
auto* python_operator_dialect =
ctx->GetOrRegisterDialect<paddle::dialect::PythonOperatorDialect>();

if (python_operator_dialect->HasRegistered(
paddle::framework::kPythonOperatorDialectPrefix + op_name + suffix)) {
return;
}
python_operator_dialect->RegisterPythonOperator(op_meta_info);
}

} // namespace paddle::framework
37 changes: 37 additions & 0 deletions paddle/fluid/framework/python_operator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <functional>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/pir/dialect/operator/ir/ir_tensor.h"
#include "paddle/phi/api/ext/op_meta_info.h"

namespace paddle {
namespace framework {

void RegisterPythonOperator(
const std::string& op_name,
std::vector<std::string>&& op_inputs,
std::vector<std::string>&& op_outputs,
std::vector<std::string>&& op_attrs,
std::unordered_map<std::string, std::string>&& op_inplace_map,
PythonOperatorFunctionType&& func,
PythonOperatorInferMetaFunctionType&& infer_meta);

} // namespace framework
} // namespace paddle
45 changes: 45 additions & 0 deletions paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,50 @@ pir::OpPrintFn CustomKernelDialect::PrintOperation(
printer.PrintOpReturnType(op);
};
}

PythonFunctionDialect::PythonFunctionDialect(pir::IrContext *context)
: pir::Dialect(name(), context, pir::TypeId::get<PythonFunctionDialect>()) {
initialize();
}

void PythonFunctionDialect::initialize() {
RegisterOps<dialect::PythonFunctionOp>();
}

void PythonFunctionDialect::PrintType(pir::Type type, std::ostream &os) const {
PrintKernelType(type, os);
}

void PythonFunctionDialect::PrintAttribute(pir::Attribute attr,
std::ostream &os) const {
PrintKernelAttribute(attr, os);
}

pir::OpPrintFn PythonFunctionDialect::PrintOperation(
const pir::Operation &op) const {
return [](const pir::Operation &op, pir::IrPrinter &printer) {
auto &os = printer.os;
printer.PrintOpResult(op);
os << " =";
auto py_func_op = op.dyn_cast<PythonFunctionOp>();
std::string kernel_name = py_func_op.kernel_name();
if (op.attributes().count("is_inplace") != 0 &&
op.attributes()
.at("is_inplace")
.dyn_cast<pir::BoolAttribute>()
.data()) {
kernel_name = kernel_name + "_";
}
os << " \"" << kernel_name << "(py_func)\"";
printer.PrintOpOperands(op);
printer.PrintAttributeMap(op);
os << " :";
printer.PrintOperandsType(op);
os << " -> ";
printer.PrintOpReturnType(op);
};
}

#ifdef PADDLE_WITH_DNNL
OneDNNKernelDialect::OneDNNKernelDialect(pir::IrContext *context)
: pir::Dialect(name(), context, pir::TypeId::get<OneDNNKernelDialect>()) {
Expand Down Expand Up @@ -258,6 +302,7 @@ pir::OpPrintFn OneDNNKernelDialect::PrintOperation(

IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::KernelDialect)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::CustomKernelDialect)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::PythonFunctionDialect)
#ifdef PADDLE_WITH_DNNL
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNKernelDialect)
#endif
18 changes: 18 additions & 0 deletions paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,23 @@ class CustomKernelDialect : public pir::Dialect {
void initialize();
};

class PythonFunctionDialect : public pir::Dialect {
public:
explicit PythonFunctionDialect(pir::IrContext* context);

static const char* name() { return "py_func"; }

void PrintType(pir::Type type, std::ostream& os) const override;

void PrintAttribute(pir::Attribute attr, std::ostream& os) const override;

pir::OpPrintFn PrintOperation(
const pir::Operation& op) const override; // NOLINT

private:
void initialize();
};

#ifdef PADDLE_WITH_DNNL
class OneDNNKernelDialect : public pir::Dialect {
public:
Expand All @@ -77,6 +94,7 @@ class OneDNNKernelDialect : public pir::Dialect {

IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::KernelDialect)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::CustomKernelDialect)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::PythonFunctionDialect)
#ifdef PADDLE_WITH_DNNL
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNKernelDialect)
#endif
41 changes: 41 additions & 0 deletions paddle/fluid/pir/dialect/kernel/ir/kernel_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,46 @@ phi::KernelKey CustomKernelOp::kernel_key() {
return attributes().at("kernel_key").dyn_cast<KernelAttribute>().data();
}

const char* PythonFunctionOp::attributes_name[attributes_num] = { // NOLINT
"op_name",
"kernel_name",
"kernel_key"};

void PythonFunctionOp::VerifySig() {
VLOG(4) << "Verifying inputs, outputs and attributes for: PythonFunctionOp.";
auto& attributes = this->attributes();

PADDLE_ENFORCE(attributes.count("op_name") > 0 &&
attributes.at("op_name").isa<pir::StrAttribute>(),
common::errors::PreconditionNotMet(
"Type of attribute: op_name is not right."));

PADDLE_ENFORCE(attributes.count("kernel_name") > 0 &&
attributes.at("kernel_name").isa<pir::StrAttribute>(),
common::errors::PreconditionNotMet(
"Type of attribute: kernel_name is not right."));

PADDLE_ENFORCE(attributes.count("kernel_key") > 0 &&
attributes.at("kernel_key").isa<KernelAttribute>(),
common::errors::PreconditionNotMet(
"Type of attribute: kernel_key is not right."));
}

std::string PythonFunctionOp::op_name() {
return attributes().at("op_name").dyn_cast<pir::StrAttribute>().AsString();
}

std::string PythonFunctionOp::kernel_name() {
return attributes()
.at("kernel_name")
.dyn_cast<pir::StrAttribute>()
.AsString();
}

phi::KernelKey PythonFunctionOp::kernel_key() {
return attributes().at("kernel_key").dyn_cast<KernelAttribute>().data();
}

#ifdef PADDLE_WITH_DNNL
const char* OneDNNPhiKernelOp::attributes_name[attributes_num] = { // NOLINT
"op_name",
Expand Down Expand Up @@ -264,6 +304,7 @@ phi::KernelKey OneDNNLegacyKernelOp::kernel_key() {
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::PhiKernelOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::LegacyKernelOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::CustomKernelOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::PythonFunctionOp)
#ifdef PADDLE_WITH_DNNL
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNPhiKernelOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNMixedPhiKernelOp)
Expand Down
13 changes: 13 additions & 0 deletions paddle/fluid/pir/dialect/kernel/ir/kernel_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,18 @@ class CustomKernelOp : public pir::Op<CustomKernelOp> {
void VerifySig();
};

class PythonFunctionOp : public pir::Op<PythonFunctionOp> {
public:
using Op::Op;
static const char *name() { return "py_func"; }
static constexpr uint32_t attributes_num = 3;
static const char *attributes_name[attributes_num];
std::string op_name();
std::string kernel_name();
phi::KernelKey kernel_key();
void VerifySig();
};

#ifdef PADDLE_WITH_DNNL
class OneDNNPhiKernelOp : public pir::Op<OneDNNPhiKernelOp> {
public:
Expand Down Expand Up @@ -100,6 +112,7 @@ class OneDNNLegacyKernelOp : public pir::Op<OneDNNLegacyKernelOp> {
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::PhiKernelOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::LegacyKernelOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::CustomKernelOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::PythonFunctionOp)
#ifdef PADDLE_WITH_DNNL
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNPhiKernelOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNMixedPhiKernelOp)
Expand Down
Loading
Loading