Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Merge remote-tracking branch 'upstream/test_pir' into dev
  • Loading branch information
qzylalala committed Oct 21, 2024
commit dc5e3891c0c05f315d5db91d49b9f82e8898c9c9
70 changes: 41 additions & 29 deletions paddle2onnx/mapper/exporter.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ inline std::string convert_pir_op_name(const std::string pir_op_name) {
if (prefix_pos != std::string::npos) {
op_name = op_name.substr(prefix_pos + prefix.size());
} else {
if(op_name.substr(0, builtin_prefix.size()) == builtin_prefix) {
op_name[builtin_prefix.size() - 1] = '_';
}
if (op_name.substr(0, builtin_prefix.size()) == builtin_prefix) {
op_name[builtin_prefix.size() - 1] = '_';
}
}
auto it = op_name_mappings.find(op_name);
if (it != op_name_mappings.end()) {
Expand Down Expand Up @@ -95,7 +95,8 @@ class ModelExporter {
bool* save_external = nullptr,
bool export_fp16_model = false,
std::vector<std::string> disable_fp16_op_types = {});
std::string Run(PaddlePirParser& pir_parser,
std::string Run(PaddlePirParser
& pir_parser,
int opset_version = 9,
bool auto_upgrade_opset = true,
bool verbose = false,
Expand Down Expand Up @@ -133,39 +134,46 @@ class ModelExporter {
//
void ExportInputOutputs(
const PaddleParser& parser,
std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>& inputs,
std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>& outputs);
std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>
& inputs,
std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>
& outputs);

void ExportInputOutputs(
const PaddlePirParser& pir_parser,
std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>& inputs,
std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>& outputs);
std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>
& inputs,
std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>
& outputs);

void ExportParameters(
const PaddleParser& parser,
std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>& parameters);
std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>
& parameters);
void ExportParameters(
const PaddlePirParser& pir_parser,
std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>& parameters);
std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>
& parameters);
// Process dumplicate tensor names in paddle model
std::set<std::string> tensor_names_;
void ProcessGraphDumplicateNames(
std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>
&parameters,
& parameters,
std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>
&inputs,
& inputs,
std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>
&outputs,
& outputs,
std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>
&nodes,
& nodes,
std::map<std::string, QuantizeInfo>
&quantize_info);
// Update constant node in parameters. When process quantize model, the weight
// dtype may be int8, it should be convet to float32 and use this function to
// update converted params.
& quantize_info);
// Update constant node in parameters. When process quantize model, the
// weight dtype may be int8, it should be convet to float32 and use this
// function to update converted params.
void UpdateParameters(
const std::map<std::string, Weight>& params,
std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>& parameters);
std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>
& parameters);
//
std::map<std::string, std::pair<int32_t, int32_t>> sub_block_map_;
ONNX_NAMESPACE::GraphProto ExportConditionalBlock(
Expand All @@ -174,27 +182,31 @@ class ModelExporter {
int32_t op_id,
const std::string& output_names);

ONNX_NAMESPACE::GraphProto ExportIfBlock(PaddlePirParser& pir_parser,
pir::Block& block);
ONNX_NAMESPACE::GraphProto ExportIfBlock(PaddlePirParser
& pir_parser,
pir::Block
& block);

ONNX_NAMESPACE::GraphProto ExportBlock(
const PaddleParser& parser,
int32_t block_id,
std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>
&parameters,
& parameters,
std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>
&inputs,
& inputs,
std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>
&outputs);

& outputs);
ONNX_NAMESPACE::GraphProto ExportBlock(
const PaddlePirParser &pir_parser,
PaddlePirParser
& pir_parser,
pir::Block* block,
std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>
&parameters,
& parameters,
std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>
&inputs,
& inputs,
std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>
&outputs);
& outputs,
bool if_in_subblock);

void ExportOp(const PaddleParser& parser,
OnnxHelper* helper,
Expand Down
Loading
You are viewing a condensed version of this merge commit. You can view the full changes here.