
Commit 8cbe9fc

Merge pull request #1426 from 0x3878f/pir_develop_6
[PIR] Support while op in pir mode
2 parents e33cd3b + b31d8bb commit 8cbe9fc

File tree

9 files changed: +339 -90 lines changed


paddle2onnx/mapper/exporter.cc

Lines changed: 32 additions & 30 deletions

@@ -33,7 +33,6 @@
 namespace paddle2onnx {
 MapperHelper* MapperHelper::helper = nullptr;
 int32_t OnnxHelper::opset_version = 7;
-
 bool ModelExporter::IsOpsRegistered(const PaddlePirParser& pir_parser,
                                     bool enable_experimental_op) {
   OnnxHelper temp_helper;
@@ -45,6 +44,9 @@ bool ModelExporter::IsOpsRegistered(const PaddlePirParser& pir_parser,
     if (op->name() == "pd_op.if") {
       continue;
     }
+    if (op->name() == "pd_op.while") {
+      continue;
+    }
     std::string op_name = convert_pir_op_name(op->name());
     if (!MapperHelper::Get()->IsRegistered(op_name)) {
       unsupported_ops.insert(op_name);
@@ -373,21 +375,22 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportIfBlock(
   std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>> temp_parameters;
   std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>> temp_inputs;
   std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>> temp_outputs;
+  std::vector<pir::Operation*> sub_blocks_ops_copy(pir_parser.sub_blocks_ops);
   pir_parser.sub_blocks_ops.clear();
   for (auto& op : block.ops()) {
     if (op->name() != "builtin.parameter") {
       pir_parser.sub_blocks_ops.push_back(op);
     }
   }
-  pir_parser.GetALLSubBlockOpOutputName(pir_parser.sub_blocks_ops);
+  pir_parser.GetAllSubBlockOpOutputName(pir_parser.sub_blocks_ops);
   if (!pir_parser.sub_blocks_ops.empty()) {
     // get cf.yeild op input
     pir::Operation* cf_yield_op = pir_parser.sub_blocks_ops.back();
-    std::vector<std::string> sub_block_outpus;
+    // std::vector<std::string> sub_block_outpus;
     for (auto oprand : cf_yield_op->operands()) {
       pir::Value value = oprand.source();
       auto cond_info = pir_parser.GetSubBlockValueTensorInfo(value);
-      sub_block_outpus.push_back(cond_info[0].name);
+      // sub_block_outpus.push_back(cond_info[0].name);
       temp_outputs.push_back(std::move(MakeValueInfo(cond_info[0])));
     }
   } else {
@@ -400,8 +403,11 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportIfBlock(
   }
 
   pir::Block* blockPtr = &block;
-  return std::move(ExportBlock(
-      pir_parser, blockPtr, temp_parameters, temp_inputs, temp_outputs, true));
+  auto graph = std::move(ExportBlock(
+      pir_parser, blockPtr, temp_parameters, temp_inputs, temp_outputs, true, false));
+  pir_parser.sub_blocks_ops.clear();
+  pir_parser.sub_blocks_ops = sub_blocks_ops_copy;
+  return graph;
 }
 
 ONNX_NAMESPACE::GraphProto ModelExporter::ExportConditionalBlock(
@@ -444,9 +450,9 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportBlock(
     std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>& parameters,
     std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>& inputs,
     std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>& outputs,
-    bool if_in_subblock) {
+    bool if_in_subblock, bool is_while_block) {
   ONNX_NAMESPACE::GraphProto graph;
-  graph.set_name("PaddlePaddle Graph in pir mode");
+  graph.set_name("PaddlePaddle Graph in PIR mode");
   OnnxHelper temp_helper;
   std::vector<pir::Operation*> block_ops;
   for (auto& op : block->ops()) {
@@ -459,11 +465,10 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportBlock(
   temp_helper.Clear();
   for (auto i = 0; i < num_ops; ++i) {
     auto op = block_ops[i];
-    if (op->name() == "pd_op.data" || op->name() == "pd_op.fetch" ||
-        op->name() == "cf.yield") {
+    if (op->name() == "pd_op.data" || op->name() == "pd_op.fetch" || op->name() == "cf.yield") {
       continue;
     }
-    if (op->name() == "pd_op.full_int_array") {
+    if (op->name() == "pd_op.full_int_array") {  // this is a trick
       bool needExport = false;
       for (auto it = op->result(0).use_begin(); it != op->result(0).use_end();
            ++it) {
@@ -489,14 +494,9 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportBlock(
       // get if op output
       auto num_results = if_op.num_results();
       std::vector<std::string> if_op_output_name;
-      if (num_results > 1) {
-        for (int i = 0; i < num_results; ++i) {
-          auto value = if_op.result(i);
-          auto out_info = pir_parser.GetTensorInfo(value);
-          if_op_output_name.push_back(out_info[0].name);
-        }
-      } else {
-        auto out_info = pir_parser.GetTensorInfo(if_op.result(0));
+      for (int i = 0; i < num_results; ++i) {
+        auto value = if_op.result(i);
+        auto out_info = pir_parser.GetTensorInfo(value);
         if_op_output_name.push_back(out_info[0].name);
       }
       auto node = temp_helper.MakeNode("If", {cond_name}, if_op_output_name);
@@ -505,6 +505,7 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportBlock(
       continue;
     }
     if (op->name() == "pd_op.while") {
+      ExportWhile(pir_parser, &temp_helper, op);
       continue;
     }
 
@@ -624,23 +625,23 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportBlock(
     UpdateParameters(temp_helper.updated_params, parameters);
   }
 
-  for (auto& item : parameters) {
+  for (auto &item : parameters) {
     *(graph.add_node()) = *(item.get());
   }
 
-  for (auto& item : inputs) {
+  for (auto &item : inputs) {
     *(graph.add_input()) = *(item.get());
   }
 
-  for (auto& item : outputs) {
+  for (auto &item : outputs) {
     *(graph.add_output()) = (*item.get());
   }
 
-  for (auto& item : temp_helper.nodes) {
+  for (auto &item : temp_helper.nodes) {
     *(graph.add_node()) = (*item.get());
   }
 
-  for (auto& item : temp_helper.value_infos) {
+  for (auto &item : temp_helper.value_infos) {
     *(graph.add_value_info()) = (*item.get());
   }
 
@@ -706,11 +707,11 @@ void ModelExporter::ExportOp(const PaddleParser& parser,
 }
 
 void ModelExporter::ProcessGraphDumplicateNames(
-    std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>& parameters,
-    std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>& inputs,
-    std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>& outputs,
-    std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>& nodes,
-    std::map<std::string, QuantizeInfo>& quantize_info) {
+    std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>> &parameters,
+    std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>> &inputs,
+    std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>> &outputs,
+    std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>> &nodes,
+    std::map<std::string, QuantizeInfo> &quantize_info) {
   std::map<std::string, std::string> renamer;
   for (auto& item : parameters) {
     for (size_t i = 0; i < item->output_size(); ++i) {
@@ -765,7 +766,7 @@ void ModelExporter::ProcessGraphDumplicateNames(
     }
   }
 
-  for (auto& item : outputs) {
+  for (auto &item : outputs) {
     if (renamer.find(item->name()) != renamer.end()) {
       auto updated_name = renamer[item->name()];
       while (renamer.find(updated_name) != renamer.end()) {
@@ -887,6 +888,7 @@ std::string ModelExporter::Run(PaddlePirParser& pir_parser,
                                  parameters,
                                  inputs,
                                  outputs,
+                                 false,
                                  false);
   *onnx_model_.mutable_graph() = share_graph;
   if (enable_onnx_checker) {

paddle2onnx/mapper/exporter.h

Lines changed: 7 additions & 8 deletions

@@ -128,14 +128,16 @@ class ModelExporter {
                   bool* save_external = nullptr,
                   bool export_fp16_model = false,
                   std::vector<std::string> disable_fp16_op_types = {});
-
 private:
  bool verbose_ = false;
  // The _deploy_backend will pass to Mapper to influence the conversion
  std::string deploy_backend_ = "onnxruntime";
  std::string* calibration_cache_ = nullptr;
  int32_t opset_version_ = 7;
 
+  void ExportWhile(PaddlePirParser& pir_parser,
+                   OnnxHelper* temp_helper,
+                   pir::Operation* op);
  bool IsOpsRegistered(const PaddleParser& parser, bool enable_experimental_op);
  bool IsOpsRegistered(const PaddlePirParser& parser,
                       bool enable_experimental_op);
@@ -219,13 +221,10 @@ class ModelExporter {
      PaddlePirParser
      & pir_parser,
      pir::Block* block,
-      std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>
-      & parameters,
-      std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>
-      & inputs,
-      std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>
-      & outputs,
-      bool if_in_subblock);
+      std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>& parameters,
+      std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>& inputs,
+      std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>& outputs,
+      bool if_in_subblock, bool is_while_block);
 
  void ExportOp(const PaddleParser& parser,
                OnnxHelper* helper,

paddle2onnx/mapper/mapper.h

Lines changed: 2 additions & 2 deletions

@@ -263,10 +263,10 @@ class Mapper {
    if (in_pir_mode) {
      if (if_in_cf_block) {
        auto op = pir_parser_->sub_blocks_ops[pir_op_idx_];
-        return pir_parser_->OpHasAttr(op, name, true);
+        return pir_parser_->OpHasAttr(op, name);
      } else {
        auto op = pir_parser_->global_blocks_ops[pir_op_idx_];
-        return pir_parser_->OpHasAttr(op, name, false);
+        return pir_parser_->OpHasAttr(op, name);
      }
    } else {
      auto& op = parser_->GetOpDesc(block_idx_, op_idx_);

paddle2onnx/mapper/tensor/less_equal.cc

Lines changed: 1 addition & 0 deletions

@@ -16,6 +16,7 @@
 
 namespace paddle2onnx {
 REGISTER_MAPPER(less_equal, LessEqualMapper)
+REGISTER_PIR_MAPPER(less_equal, LessEqualMapper)
 
 void LessEqualMapper::Opset7() {
   auto x_info = GetInput("X");

paddle2onnx/mapper/tensor/less_equal.h

Lines changed: 3 additions & 0 deletions

@@ -22,6 +22,9 @@ class LessEqualMapper : public Mapper {
  LessEqualMapper(const PaddleParser& p, OnnxHelper* helper, int64_t block_id,
                  int64_t op_id)
      : Mapper(p, helper, block_id, op_id) {}
+
+  LessEqualMapper(const PaddlePirParser& p, OnnxHelper* helper, int64_t i, bool c)
+      : Mapper(p, helper, i, c) {}
  void Opset7() override;
  void Opset12() override;
 };

paddle2onnx/mapper/while.cc

Lines changed: 119 additions & 0 deletions

@@ -0,0 +1,119 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle2onnx/mapper/exporter.h"
+#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h"
+namespace paddle2onnx {
+void ModelExporter::ExportWhile(PaddlePirParser& pir_parser, OnnxHelper* temp_helper, pir::Operation* op) {
+  // ================================
+  // construct loop body sub graph
+  // ================================
+  std::vector<TensorInfo> inputs_info;
+  std::vector<TensorInfo> outputs_info;
+  auto while_op = op->dyn_cast<paddle::dialect::WhileOp>();
+  auto cond_info = pir_parser.GetTensorInfo(while_op.cond());
+  // mapping args and inputs in while op using while_op_input_value_map
+  std::vector<pir::detail::ValueImpl*> while_op_input_value_address;
+  std::vector<pir::detail::ValueImpl*> while_op_input_arg_address;
+  pir_parser.while_op_input_value_map.clear();  // wangmingkai02: handle nested loop situations in future.
+
+  // record input value address
+  for (int index = 1; index < while_op.num_operands(); index++) {
+    const pir::Value& value = while_op.operand_source(index);
+    inputs_info.push_back(pir_parser.GetTensorInfo(pir_parser.GetOpOutputName(value), value.type()));
+    while_op_input_value_address.push_back(&(*(value).impl()));  // get value address
+  }
+  // record args value address
+  std::vector<pir::Value> args = while_op.block_args();
+  for (int i = 0; i < args.size(); i++) {
+    const pir::Value& value = args[i];
+    while_op_input_arg_address.push_back(&(*(value.impl())));
+  }
+
+  // mapping
+  for (int index = 0; index < while_op_input_value_address.size(); index++) {
+    pir_parser.while_op_input_value_map[while_op_input_arg_address[index]] = while_op_input_value_address[index];
+  }
+
+  std::vector<pir::Operation*> sub_blocks_ops_copy(pir_parser.sub_blocks_ops);
+  pir_parser.sub_blocks_ops.clear();
+  auto& body_block = while_op.body();
+  for (auto& op : body_block.ops()) {
+    if (op->name() != "builtin.parameter") {
+      pir_parser.sub_blocks_ops.push_back(op);
+    }
+  }
+
+  pir_parser.GetAllSubBlockOpOutputName(pir_parser.sub_blocks_ops);
+  if (!pir_parser.sub_blocks_ops.empty()) {
+    // get cf.yeild op input
+    pir::Operation* cf_yield_op = pir_parser.sub_blocks_ops.back();
+    PADDLE_ENFORCE_EQ(cf_yield_op->name(),
+                      "cf.yield",
+                      ::common::errors::InvalidArgument(
+                          "The last op of a control flow sub-block must be cf.yield"));
+    for (auto oprand : cf_yield_op->operands()) {
+      pir::Value value = oprand.source();
+      auto info = pir_parser.GetSubBlockValueTensorInfo(value);
+      outputs_info.push_back(info[0]);
+    }
+
+  } else {
+    // sub_blocks_ops is empty
+    PADDLE_ENFORCE_NE(pir_parser.sub_blocks_ops.size(),
+                      0,
+                      ::common::errors::InvalidArgument(
+                          "The number of ops of a control flow sub-block "
+                          "cannot be zero."));
+  }
+
+  ONNX_NAMESPACE::GraphProto graph;
+  std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>> parameters;
+  std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>> inputs;
+  std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>> outputs;
+  auto iter_name = MapperHelper::Get()->GenName("loop.iter");
+  TensorInfo iter_info(iter_name, std::vector<int64_t>(1, 1),
+                       P2ODataType::INT64);
+  // inputs
+  inputs.push_back(std::move(MakeValueInfo(iter_info)));
+  inputs.push_back(std::move(MakeValueInfo(cond_info[0])));
+  for (size_t i = 0; i < inputs_info.size(); ++i) {
+    inputs.push_back(std::move(MakeValueInfo(inputs_info[i])));
+  }
+  // outputs
+  for (size_t i = 0; i < outputs_info.size(); ++i) {
+    outputs.push_back(std::move(MakeValueInfo(outputs_info[i])));
+  }
+  pir::Block* blockPtr = &body_block;
+  graph = ExportBlock(pir_parser, blockPtr, parameters, inputs, outputs, true, true);
+  pir_parser.sub_blocks_ops.clear();
+  pir_parser.sub_blocks_ops = sub_blocks_ops_copy;
+
+  // =====================
+  // construct loop node
+  // =====================
+  std::vector<std::string> input_names;
+  std::vector<std::string> output_names;
+  input_names.push_back("");  // skip max loop iter
+  input_names.push_back(cond_info[0].name);
+  for (size_t i = 0; i < inputs_info.size(); ++i) {
+    input_names.push_back(inputs_info[i].name);
+  }
+  for (size_t i = 0; i < op->num_results(); i++) {
+    output_names.push_back(pir_parser.GetOpOutputName(op->result(i)));
+  }
+  auto loop_node = temp_helper->MakeNode("Loop", input_names, output_names);
+  AddAttribute(loop_node, "body", graph);
+}
+}  // namespace paddle2onnx
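
For context, the ONNX Loop operator that ExportWhile emits has a fixed contract: the node's inputs are an optional maximum trip count, an initial condition, and the initial loop-carried values; its body graph must take (iteration_num, cond_in, carried values...) as inputs and produce (cond_out, updated carried values...) as outputs. That is why ExportWhile prepends a generated "loop.iter" tensor and the while condition to the body inputs, takes the body outputs from the operands of the block's cf.yield, and passes an empty string as the first Loop input to skip the trip count. Below is a minimal, hand-written sketch of that contract against the ONNX protobuf C++ API; it is not code from this PR, the function name BuildLoopSketch and the trivial Identity body are invented for illustration, and it assumes the ONNX protobufs are generated in the default onnx namespace.

// Hypothetical sketch of the ONNX Loop contract targeted by ExportWhile.
// Not part of paddle2onnx; assumes the onnx C++ protobuf headers.
#include <string>
#include "onnx/onnx_pb.h"

onnx::NodeProto BuildLoopSketch() {
  // Body graph: inputs are (iteration_num, cond_in, x_in),
  // outputs are (cond_out, x_out); carried-value counts must match.
  onnx::GraphProto body;
  body.set_name("loop_body");

  auto add_value = [](onnx::ValueInfoProto* vi, const std::string& name,
                      onnx::TensorProto::DataType t) {
    vi->set_name(name);
    vi->mutable_type()->mutable_tensor_type()->set_elem_type(t);
  };
  add_value(body.add_input(), "iter", onnx::TensorProto::INT64);
  add_value(body.add_input(), "cond_in", onnx::TensorProto::BOOL);
  add_value(body.add_input(), "x_in", onnx::TensorProto::FLOAT);

  // Pass the condition and the carried value straight through (Identity);
  // a real body would recompute the condition each iteration.
  auto* id_cond = body.add_node();
  id_cond->set_op_type("Identity");
  id_cond->add_input("cond_in");
  id_cond->add_output("cond_out");
  auto* id_x = body.add_node();
  id_x->set_op_type("Identity");
  id_x->add_input("x_in");
  id_x->add_output("x_out");

  add_value(body.add_output(), "cond_out", onnx::TensorProto::BOOL);
  add_value(body.add_output(), "x_out", onnx::TensorProto::FLOAT);

  // Loop node: the empty first input omits the max trip count, exactly as
  // ExportWhile does with input_names.push_back("").
  onnx::NodeProto loop;
  loop.set_op_type("Loop");
  loop.add_input("");        // optional max trip count, skipped
  loop.add_input("cond0");   // initial condition produced earlier in the graph
  loop.add_input("x0");      // initial loop-carried value
  loop.add_output("x_final");
  auto* attr = loop.add_attribute();
  attr->set_name("body");
  attr->set_type(onnx::AttributeProto::GRAPH);
  *attr->mutable_g() = body;
  return loop;
}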
