Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/operator/ir/op_dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ inline bool IsCustomOp(pir::Operation* op) {
return op_name.find("custom_op") != op_name.npos;
}

inline bool IsCustomPyOp(pir::Operation* op) {
// Returns true when `op` is a python-defined operator, identified by the
// "py_op" marker embedded in its dialect op name.
inline bool IsPythonOp(pir::Operation* op) {
  static const std::string kPyOpMarker = "py_op";
  return op->name().find(kPyOpMarker) != std::string::npos;
}
Expand Down
113 changes: 112 additions & 1 deletion paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2505,6 +2505,105 @@ void HandleForCustomOp(
block->push_back(op);
}

// Lowers a python-defined operator (an op whose name contains "py_op", see
// IsPythonOp) into a PythonFunctionOp in the kernel dialect, wiring its
// inputs and outputs into the already-lowered program.
//
// ctx            - IR context used to build attributes/types and look up ops.
// op_item        - the original python op being lowered.
// kernel_key     - kernel backend/layout/dtype selected for this op.
// place          - target place of the whole lowering pass.
// op_info_parser - unused here; kept for signature parity with the other
//                  HandleFor* lowering helpers.
// map_op_pair    - out: original op -> lowered op.
// map_value_pair - out: original result value -> lowered result value.
// block          - block that receives the newly created op.
void HandleForPythonOp(
    pir::IrContext* ctx,
    pir::Operation* op_item,
    const phi::KernelKey& kernel_key,
    const phi::Place place,
    const OpYamlInfoParser* op_info_parser,
    std::unordered_map<pir::Operation*, pir::Operation*>* map_op_pair,
    std::unordered_map<pir::Value, pir::Value>* map_value_pair,
    pir::Block* block) {
  // Prepare output types: each result is re-typed onto the place implied by
  // the chosen kernel backend.
  std::vector<pir::Type> op_output_types;
  for (size_t i = 0; i < op_item->num_results(); ++i) {
    phi::Place out_place = phi::TransToPhiPlace(kernel_key.backend());
    PushBackOutputTypes(ctx,
                        op_item,
                        op_item->result(i).type(),
                        out_place,
                        kernel_key,
                        &op_output_types);
  }

  // Prepare inputs: every non-null operand must already have been lowered and
  // recorded in map_value_pair by an earlier iteration of the pass.
  std::vector<pir::Value> vec_inputs;
  for (size_t i = 0; i < op_item->num_operands(); ++i) {
    auto cur_in = op_item->operand_source(i);
    if (!cur_in) {
      // Preserve the operand slot for optional (null) inputs.
      vec_inputs.emplace_back();
      continue;
    }
    PADDLE_ENFORCE_EQ(
        map_value_pair->count(cur_in),
        true,
        common::errors::PreconditionNotMet(
            "[%d]'s input of [%s] op MUST in map pair", i, op_item->name()));

    auto new_in = map_value_pair->at(cur_in);
    auto new_in_type = new_in.type();

    if (new_in_type.isa<AllocatedDenseTensorType>()) {
      auto in_place = new_in_type.dyn_cast<AllocatedDenseTensorType>().place();
      // Need trans from GPU_PINNED to GPU, refer to PR#41972.
      // NOTE(review): this tests the pass-level target `place`, not the
      // input's `in_place`; it mirrors HandleForCustomOp, but confirm the
      // intent — a per-input pinned check would read `in_place.GetType()`.
      if (phi::AllocationType::GPUPINNED == place.GetType()) {
        // Build a memcpy op that moves the tensor onto the GPU.
        auto out_place = phi::TransToPhiPlace(phi::Backend::GPU);
        auto new_in_alloc_type =
            new_in_type.dyn_cast<AllocatedDenseTensorType>();
        auto out_type =
            AllocatedDenseTensorType::get(ctx,
                                          out_place,
                                          new_in_alloc_type.dtype(),
                                          new_in_alloc_type.dims(),
                                          new_in_alloc_type.data_layout(),
                                          new_in_alloc_type.lod(),
                                          new_in_alloc_type.offset());
        new_in = AddPlaceTransferOp(
            new_in, out_type, in_place, out_place, kernel_key, block);
      }
    }

    vec_inputs.push_back(new_in);
  }

  // Prepare attributes: record the op/kernel identity and the kernel key,
  // then carry over every attribute of the original op (emplace keeps the
  // identity entries when names collide).
  std::unordered_map<std::string, pir::Attribute> op_attribute{
      {"op_name", pir::StrAttribute::get(ctx, op_item->name())},
      {"kernel_name", pir::StrAttribute::get(ctx, op_item->name())},
      {"kernel_key", KernelAttribute::get(ctx, kernel_key)}};

  auto op_attr_map = op_item->attributes();
  for (auto& map_item : op_attr_map) {
    op_attribute.emplace(map_item.first, map_item.second);
  }
  if (op_item->HasTrait<InplaceTrait>()) {
    op_attribute.emplace("is_inplace", pir::BoolAttribute::get(ctx, true));
  }
  // origin_id records the id of the ORIGINAL op so the lowered op can be
  // traced back to its source. (The previous code additionally overwrote it
  // after creation with the new op's own id, which defeated that purpose.)
  op_attribute.emplace("origin_id",
                       pir::Int64Attribute::get(ctx, op_item->id()));

  VLOG(6) << "Lower custom pyop: " << op_item->name()
          << " to : " << PythonFunctionOp::name();

  pir::OpInfo py_func_op_info =
      ctx->GetRegisteredOpInfo(PythonFunctionOp::name());

  pir::Operation* op = pir::Operation::Create(
      vec_inputs, op_attribute, op_output_types, py_func_op_info);

  // Register the mappings so later ops can resolve their operands against
  // the lowered program.
  (*map_op_pair)[op_item] = op;
  for (size_t i = 0; i < op_item->num_results(); ++i) {
    (*map_value_pair)[op_item->result(i)] = op->result(i);
  }
  block->push_back(op);
}

void HandleForTensorRTOp(
pir::IrContext* ctx,
pir::Operation* op_item,
Expand Down Expand Up @@ -3588,7 +3687,6 @@ void ProcessBlock(
auto kernel_name = GetKernelName(op_info_parser.get(), op_item);
auto kernel_key = GetKernelKey(
op_item, place, kernel_name, *map_value_pair, op_info_parser.get());
VLOG(6) << "kernel type " << kernel_key;

if (paddle::dialect::IsCustomOp(op_item)) {
HandleForCustomOp(ctx,
Expand All @@ -3602,6 +3700,18 @@ void ProcessBlock(
continue;
}

if (paddle::dialect::IsPythonOp(op_item)) {
HandleForPythonOp(ctx,
op_item,
kernel_key,
place,
op_info_parser.get(),
map_op_pair,
map_value_pair,
new_block);
continue;
}

if (paddle::dialect::IsTensorRTOp(op_item)) {
HandleForTensorRTOp(ctx,
op_item,
Expand Down Expand Up @@ -3714,6 +3824,7 @@ std::unique_ptr<pir::Program> PdOpLowerToKernelPass(pir::Program* prog,
ctx->GetOrRegisterDialect<OperatorDialect>();
ctx->GetOrRegisterDialect<KernelDialect>();
ctx->GetOrRegisterDialect<CustomKernelDialect>();
ctx->GetOrRegisterDialect<PythonFunctionDialect>();

#ifdef PADDLE_WITH_DNNL
ctx->GetOrRegisterDialect<OneDNNOperatorDialect>();
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/pybind/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ set(PYBIND_SRCS
graph.cc
ir_tensor.cc
ir_meta_tensor.cc
native_meta_tensor.cc
reader_py.cc
protobuf.cc
exception.cc
Expand Down
85 changes: 0 additions & 85 deletions paddle/fluid/pybind/ir_meta_tensor.cc

This file was deleted.

91 changes: 0 additions & 91 deletions paddle/fluid/pybind/ir_tensor.cc

This file was deleted.

22 changes: 0 additions & 22 deletions paddle/fluid/pybind/ir_tensor.h

This file was deleted.

Loading
Loading