 namespace paddle2onnx {
 MapperHelper* MapperHelper::helper = nullptr;
 int32_t OnnxHelper::opset_version = 7;
-
 bool ModelExporter::IsOpsRegistered(const PaddlePirParser& pir_parser,
                                     bool enable_experimental_op) {
   OnnxHelper temp_helper;
@@ -45,6 +44,9 @@ bool ModelExporter::IsOpsRegistered(const PaddlePirParser& pir_parser,
     if (op->name() == "pd_op.if") {
       continue;
     }
+    if (op->name() == "pd_op.while") {
+      continue;
+    }
     std::string op_name = convert_pir_op_name(op->name());
     if (!MapperHelper::Get()->IsRegistered(op_name)) {
       unsupported_ops.insert(op_name);
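Note on the hunk above: with these early `continue`s, `pd_op.if` and `pd_op.while` are treated as supported without consulting the mapper registry, since control flow is lowered through `ExportIfBlock`/`ExportWhile` rather than per-op mappers. A minimal sketch of the same skip factored into a helper; `IsControlFlowOp` is hypothetical and not part of paddle2onnx:

#include <string>

// Hypothetical helper mirroring the skip pattern in IsOpsRegistered;
// not part of the paddle2onnx codebase.
inline bool IsControlFlowOp(const std::string& op_name) {
  return op_name == "pd_op.if" || op_name == "pd_op.while";
}
// Usage inside the loop over block ops (sketch):
//   if (IsControlFlowOp(op->name())) continue;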
@@ -373,21 +375,22 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportIfBlock(
   std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>> temp_parameters;
   std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>> temp_inputs;
   std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>> temp_outputs;
+  std::vector<pir::Operation*> sub_blocks_ops_copy(pir_parser.sub_blocks_ops);
   pir_parser.sub_blocks_ops.clear();
   for (auto& op : block.ops()) {
     if (op->name() != "builtin.parameter") {
       pir_parser.sub_blocks_ops.push_back(op);
     }
   }
-  pir_parser.GetALLSubBlockOpOutputName(pir_parser.sub_blocks_ops);
+  pir_parser.GetAllSubBlockOpOutputName(pir_parser.sub_blocks_ops);
   if (!pir_parser.sub_blocks_ops.empty()) {
     // get cf.yield op input
     pir::Operation* cf_yield_op = pir_parser.sub_blocks_ops.back();
-    std::vector<std::string> sub_block_outpus;
+    // std::vector<std::string> sub_block_outpus;
     for (auto oprand : cf_yield_op->operands()) {
       pir::Value value = oprand.source();
       auto cond_info = pir_parser.GetSubBlockValueTensorInfo(value);
-      sub_block_outpus.push_back(cond_info[0].name);
+      // sub_block_outpus.push_back(cond_info[0].name);
       temp_outputs.push_back(std::move(MakeValueInfo(cond_info[0])));
     }
   } else {
@@ -400,8 +403,11 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportIfBlock(
   }
 
   pir::Block* blockPtr = &block;
-  return std::move(ExportBlock(
-      pir_parser, blockPtr, temp_parameters, temp_inputs, temp_outputs, true));
+  auto graph = std::move(ExportBlock(
+      pir_parser, blockPtr, temp_parameters, temp_inputs, temp_outputs, true, false));
+  pir_parser.sub_blocks_ops.clear();
+  pir_parser.sub_blocks_ops = sub_blocks_ops_copy;
+  return graph;
 }
 
 ONNX_NAMESPACE::GraphProto ModelExporter::ExportConditionalBlock(
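The `sub_blocks_ops_copy` saved at the top of `ExportIfBlock` and the clear/restore after the recursive `ExportBlock` call keep the parser's shared `sub_blocks_ops` list intact for the enclosing block once the sub-graph has been exported. A sketch of the same save/restore expressed as an RAII guard, assuming `sub_blocks_ops` is a `std::vector<pir::Operation*>`; `SubBlockOpsGuard` is hypothetical, not a type in paddle2onnx:

#include <utility>
#include <vector>

namespace pir { class Operation; }  // provided by Paddle's PIR headers

// Hypothetical guard: copies the parser's sub-block op list on entry and
// restores it on scope exit, mirroring the manual copy/clear/restore above.
struct SubBlockOpsGuard {
  std::vector<pir::Operation*>& ops;   // e.g. pir_parser.sub_blocks_ops
  std::vector<pir::Operation*> saved;  // copy taken on construction
  explicit SubBlockOpsGuard(std::vector<pir::Operation*>& o) : ops(o), saved(o) {}
  ~SubBlockOpsGuard() { ops = std::move(saved); }
};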
@@ -444,9 +450,9 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportBlock(
     std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>& parameters,
     std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>& inputs,
     std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>& outputs,
-    bool if_in_subblock) {
+    bool if_in_subblock, bool is_while_block) {
   ONNX_NAMESPACE::GraphProto graph;
-  graph.set_name("PaddlePaddle Graph in pir mode");
+  graph.set_name("PaddlePaddle Graph in PIR mode");
   OnnxHelper temp_helper;
   std::vector<pir::Operation*> block_ops;
   for (auto& op : block->ops()) {
@@ -459,11 +465,10 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportBlock(
   temp_helper.Clear();
   for (auto i = 0; i < num_ops; ++i) {
     auto op = block_ops[i];
-    if (op->name() == "pd_op.data" || op->name() == "pd_op.fetch" ||
-        op->name() == "cf.yield") {
+    if (op->name() == "pd_op.data" || op->name() == "pd_op.fetch" || op->name() == "cf.yield") {
      continue;
    }
-    if (op->name() == "pd_op.full_int_array") {
+    if (op->name() == "pd_op.full_int_array") {  // this is a trick
      bool needExport = false;
      for (auto it = op->result(0).use_begin(); it != op->result(0).use_end();
           ++it) {
@@ -489,14 +494,9 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportBlock(
       // get if op output
       auto num_results = if_op.num_results();
       std::vector<std::string> if_op_output_name;
-      if (num_results > 1) {
-        for (int i = 0; i < num_results; ++i) {
-          auto value = if_op.result(i);
-          auto out_info = pir_parser.GetTensorInfo(value);
-          if_op_output_name.push_back(out_info[0].name);
-        }
-      } else {
-        auto out_info = pir_parser.GetTensorInfo(if_op.result(0));
+      for (int i = 0; i < num_results; ++i) {
+        auto value = if_op.result(i);
+        auto out_info = pir_parser.GetTensorInfo(value);
         if_op_output_name.push_back(out_info[0].name);
       }
       auto node = temp_helper.MakeNode("If", {cond_name}, if_op_output_name);
@@ -505,6 +505,7 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportBlock(
       continue;
     }
     if (op->name() == "pd_op.while") {
+      ExportWhile(pir_parser, &temp_helper, op);
       continue;
     }
 
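From the call site `ExportWhile(pir_parser, &temp_helper, op)` added above, a declaration along the following lines would be expected on `ModelExporter`; the exact parameter types and constness are an assumption, and the implementation would presumably lower the `pd_op.while` sub-block to an ONNX Loop/control-flow construct:

// Assumed shape of the declaration matching the new call site; the real
// signature in the exporter header may differ.
void ExportWhile(PaddlePirParser& pir_parser,
                 OnnxHelper* temp_helper,
                 pir::Operation* op);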
@@ -624,23 +625,23 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportBlock(
     UpdateParameters(temp_helper.updated_params, parameters);
   }
 
-  for (auto& item : parameters) {
+  for (auto& item : parameters) {
     *(graph.add_node()) = *(item.get());
   }
 
-  for (auto& item : inputs) {
+  for (auto& item : inputs) {
     *(graph.add_input()) = *(item.get());
   }
 
-  for (auto& item : outputs) {
+  for (auto& item : outputs) {
     *(graph.add_output()) = (*item.get());
   }
 
-  for (auto& item : temp_helper.nodes) {
+  for (auto& item : temp_helper.nodes) {
     *(graph.add_node()) = (*item.get());
   }
 
-  for (auto& item : temp_helper.value_infos) {
+  for (auto& item : temp_helper.value_infos) {
     *(graph.add_value_info()) = (*item.get());
   }
 
@@ -706,11 +707,11 @@ void ModelExporter::ExportOp(const PaddleParser& parser,
 }
 
 void ModelExporter::ProcessGraphDumplicateNames(
-    std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>& parameters,
-    std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>& inputs,
-    std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>& outputs,
-    std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>& nodes,
-    std::map<std::string, QuantizeInfo>& quantize_info) {
+    std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>> & parameters,
+    std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>> & inputs,
+    std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>> & outputs,
+    std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>> & nodes,
+    std::map<std::string, QuantizeInfo> & quantize_info) {
   std::map<std::string, std::string> renamer;
   for (auto& item : parameters) {
     for (size_t i = 0; i < item->output_size(); ++i) {
@@ -765,7 +766,7 @@ void ModelExporter::ProcessGraphDumplicateNames(
     }
   }
 
-  for (auto& item : outputs) {
+  for (auto& item : outputs) {
     if (renamer.find(item->name()) != renamer.end()) {
       auto updated_name = renamer[item->name()];
       while (renamer.find(updated_name) != renamer.end()) {
@@ -887,6 +888,7 @@ std::string ModelExporter::Run(PaddlePirParser& pir_parser,
                                    parameters,
                                    inputs,
                                    outputs,
+                                   false,
                                    false);
   *onnx_model_.mutable_graph() = share_graph;
   if (enable_onnx_checker) {
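The extra `false` passed here lines up with the two boolean parameters `ExportBlock` carries after this diff (`if_in_subblock`, then `is_while_block`); the top-level graph is neither. Collecting the pieces visible in this diff, the updated declaration would look roughly like the sketch below (member of `ModelExporter`; the leading parameters are inferred from the `ExportIfBlock` call site and may not match the header exactly):

// Sketch of the post-diff ExportBlock declaration, assembled from the call
// sites and the signature hunk above; treat as illustrative, not authoritative.
ONNX_NAMESPACE::GraphProto ExportBlock(
    PaddlePirParser& pir_parser,
    pir::Block* block,
    std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>& parameters,
    std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>& inputs,
    std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>& outputs,
    bool if_in_subblock,
    bool is_while_block);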