Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 112 additions & 1 deletion paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2505,6 +2505,105 @@ void HandleForCustomOp(
block->push_back(op);
}

// Lowers a custom Python op from the operator dialect into a
// PythonFunctionOp in the kernel dialect: rewires inputs through
// map_value_pair, inserts a pinned->GPU memcpy where required, copies the
// original attributes, and records the old->new op/value mappings so later
// ops in the block can find their producers.
//
// ctx            - IR context used to build types/attributes and look up op
//                  info.
// op_item        - the original (pre-lowering) operation being replaced.
// kernel_key     - resolved kernel key; its backend decides output placement.
// place          - target place for this block's lowering.
// op_info_parser - currently unused here; kept for signature parity with the
//                  other HandleFor* helpers.
// map_op_pair    - out: original op -> lowered op.
// map_value_pair - in/out: original value -> lowered value.
// block          - destination block; the new op is appended to it.
void HandleForPythonOp(
    pir::IrContext* ctx,
    pir::Operation* op_item,
    const phi::KernelKey& kernel_key,
    const phi::Place place,
    const OpYamlInfoParser* op_info_parser,
    std::unordered_map<pir::Operation*, pir::Operation*>* map_op_pair,
    std::unordered_map<pir::Value, pir::Value>* map_value_pair,
    pir::Block* block) {
  // Prepare output types. The output place depends only on the kernel
  // backend, so compute it once instead of per result.
  std::vector<pir::Type> op_output_types;
  const phi::Place out_place = phi::TransToPhiPlace(kernel_key.backend());
  for (size_t i = 0; i < op_item->num_results(); ++i) {
    PushBackOutputTypes(ctx,
                        op_item,
                        op_item->result(i).type(),
                        out_place,
                        kernel_key,
                        &op_output_types);
  }

  // Prepare inputs: translate each operand to its already-lowered value.
  std::vector<pir::Value> vec_inputs;
  for (size_t i = 0; i < op_item->num_operands(); ++i) {
    auto cur_in = op_item->operand_source(i);
    if (!cur_in) {
      // Preserve operand slot for optional/absent inputs.
      vec_inputs.emplace_back();
      continue;
    }
    PADDLE_ENFORCE_EQ(
        map_value_pair->count(cur_in),
        true,
        common::errors::PreconditionNotMet(
            "[%d]'s input of [%s] op MUST in map pair", i, op_item->name()));

    auto new_in = map_value_pair->at(cur_in);
    auto new_in_type = new_in.type();

    if (new_in_type.isa<AllocatedDenseTensorType>()) {
      auto in_place = new_in_type.dyn_cast<AllocatedDenseTensorType>().place();
      // need trans from GPU_PINNED to GPU, refer to PR#41972
      // NOTE(review): this tests the *target* `place` for GPUPINNED, not the
      // input's `in_place` — confirm that is the intended trigger for the
      // transfer.
      if (phi::AllocationType::GPUPINNED == place.GetType()) {
        // Build a memcpy op moving the tensor onto the GPU.
        auto gpu_place = phi::TransToPhiPlace(phi::Backend::GPU);
        auto new_in_alloc_type =
            new_in_type.dyn_cast<AllocatedDenseTensorType>();
        auto transferred_type =
            AllocatedDenseTensorType::get(ctx,
                                          gpu_place,
                                          new_in_alloc_type.dtype(),
                                          new_in_alloc_type.dims(),
                                          new_in_alloc_type.data_layout(),
                                          new_in_alloc_type.lod(),
                                          new_in_alloc_type.offset());
        new_in = AddPlaceTransferOp(
            new_in, transferred_type, in_place, gpu_place, kernel_key, block);
      }
    }

    vec_inputs.push_back(new_in);
  }

  // Prepare attributes: kernel bookkeeping first, then the original op's
  // attributes (emplace keeps the bookkeeping entries on key collision).
  std::unordered_map<std::string, pir::Attribute> op_attribute{
      {"op_name", pir::StrAttribute::get(ctx, op_item->name())},
      {"kernel_name", pir::StrAttribute::get(ctx, op_item->name())},
      {"kernel_key", KernelAttribute::get(ctx, kernel_key)}};

  for (auto& map_item : op_item->attributes()) {
    op_attribute.emplace(map_item.first, map_item.second);
  }
  if (op_item->HasTrait<InplaceTrait>()) {
    op_attribute.emplace("is_inplace", pir::BoolAttribute::get(ctx, true));
  }
  // "origin_id" is set once, after Create, from the *new* op's id. (The old
  // code also emplaced op_item->id() beforehand, but set_attribute
  // unconditionally overwrote it, so that store was dead.)

  VLOG(6) << "Lower custom pyop: " << op_item->name()
          << " to : " << PythonFunctionOp::name();

  pir::OpInfo py_func_op_info =
      ctx->GetRegisteredOpInfo(PythonFunctionOp::name());

  pir::Operation* op = pir::Operation::Create(
      vec_inputs, op_attribute, op_output_types, py_func_op_info);
  op->set_attribute("origin_id", pir::Int64Attribute::get(ctx, op->id()));

  // Record mappings so downstream ops resolve their producers to the
  // lowered op/values. The loop is a no-op when there are no results.
  (*map_op_pair)[op_item] = op;
  for (size_t i = 0; i < op_item->num_results(); ++i) {
    (*map_value_pair)[op_item->result(i)] = op->result(i);
  }
  block->push_back(op);
}

void HandleForTensorRTOp(
pir::IrContext* ctx,
pir::Operation* op_item,
Expand Down Expand Up @@ -3588,7 +3687,6 @@ void ProcessBlock(
auto kernel_name = GetKernelName(op_info_parser.get(), op_item);
auto kernel_key = GetKernelKey(
op_item, place, kernel_name, *map_value_pair, op_info_parser.get());
VLOG(6) << "kernel type " << kernel_key;

if (paddle::dialect::IsCustomOp(op_item)) {
HandleForCustomOp(ctx,
Expand All @@ -3602,6 +3700,18 @@ void ProcessBlock(
continue;
}

if (paddle::dialect::IsCustomPyOp(op_item)) {
HandleForPythonOp(ctx,
op_item,
kernel_key,
place,
op_info_parser.get(),
map_op_pair,
map_value_pair,
new_block);
continue;
}

if (paddle::dialect::IsTensorRTOp(op_item)) {
HandleForTensorRTOp(ctx,
op_item,
Expand Down Expand Up @@ -3714,6 +3824,7 @@ std::unique_ptr<pir::Program> PdOpLowerToKernelPass(pir::Program* prog,
ctx->GetOrRegisterDialect<OperatorDialect>();
ctx->GetOrRegisterDialect<KernelDialect>();
ctx->GetOrRegisterDialect<CustomKernelDialect>();
ctx->GetOrRegisterDialect<PythonFunctionDialect>();

#ifdef PADDLE_WITH_DNNL
ctx->GetOrRegisterDialect<OneDNNOperatorDialect>();
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/pybind/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ set(PYBIND_SRCS
graph.cc
ir_tensor.cc
ir_meta_tensor.cc
native_meta_tensor.cc
reader_py.cc
protobuf.cc
exception.cc
Expand Down
85 changes: 0 additions & 85 deletions paddle/fluid/pybind/ir_meta_tensor.cc

This file was deleted.

91 changes: 0 additions & 91 deletions paddle/fluid/pybind/ir_tensor.cc

This file was deleted.

22 changes: 0 additions & 22 deletions paddle/fluid/pybind/ir_tensor.h

This file was deleted.

Loading
Loading