Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
[OSPP][PIR] support some ops in pir
  • Loading branch information
qzylalala committed Oct 21, 2024
commit f0c4b877119e7326f7b26876390fb3215c48fa40
19 changes: 17 additions & 2 deletions paddle2onnx/mapper/activation/activation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,26 @@
namespace paddle2onnx {

REGISTER_MAPPER(abs, ActivationMapper)
REGISTER_PIR_MAPPER(abs, ActivationMapper)
REGISTER_MAPPER(acos, ActivationMapper)
REGISTER_MAPPER(asin, ActivationMapper)
REGISTER_MAPPER(atan, ActivationMapper)
REGISTER_MAPPER(brelu, BReluMapper)
REGISTER_MAPPER(ceil, ActivationMapper)
REGISTER_MAPPER(cos, ActivationMapper)
REGISTER_PIR_MAPPER(cos, ActivationMapper)
REGISTER_MAPPER(elu, EluMapper)
REGISTER_MAPPER(erf, ActivationMapper)
REGISTER_MAPPER(exp, ActivationMapper)
REGISTER_PIR_MAPPER(exp, ActivationMapper)
REGISTER_MAPPER(floor, ActivationMapper)
REGISTER_PIR_MAPPER(floor, ActivationMapper)
REGISTER_MAPPER(gelu, GeluMapper)
REGISTER_PIR_MAPPER(gelu, GeluMapper)
REGISTER_MAPPER(leaky_relu, LeakyReluMapper)
REGISTER_PIR_MAPPER(leaky_relu, LeakyReluMapper)
REGISTER_MAPPER(log, ActivationMapper)
REGISTER_PIR_MAPPER(log, ActivationMapper)
REGISTER_MAPPER(log10, Log10Mapper)
REGISTER_MAPPER(log1p, Log1PMapper)
REGISTER_MAPPER(log2, Log2Mapper)
Expand All @@ -45,13 +52,17 @@ REGISTER_MAPPER(rsqrt, RsqrtMapper)
REGISTER_MAPPER(sel, ActivationMapper)
REGISTER_MAPPER(selu, SeluMapper)
REGISTER_MAPPER(silu, SiluMapper)
REGISTER_PIR_MAPPER(silu, SiluMapper)
REGISTER_MAPPER(sin, ActivationMapper)
REGISTER_PIR_MAPPER(sin, ActivationMapper)
REGISTER_MAPPER(size, SizeMapper)
REGISTER_MAPPER(softmax, SoftMaxMapper)
REGISTER_PIR_MAPPER(softmax, SoftMaxMapper)
REGISTER_MAPPER(softplus, ActivationMapper)
REGISTER_MAPPER(softshrink, SoftShrinkMapper)
REGISTER_MAPPER(softsign, ActivationMapper)
REGISTER_MAPPER(sqrt, ActivationMapper)
REGISTER_PIR_MAPPER(sqrt, ActivationMapper)
REGISTER_MAPPER(square, SquareMapper)
REGISTER_MAPPER(tan, ActivationMapper)
REGISTER_MAPPER(tanh, ActivationMapper)
Expand Down Expand Up @@ -85,7 +96,9 @@ void ActivationMapper::Opset7() {
auto output_info = GetOutput("Out");
auto iter = op_mapper_.find(convert_pir_op_name(OpType()));
Assert(op_mapper_.end() != iter,
"Cannot find " + convert_pir_op_name(OpType()) + " in activation op_mapper.");
"Cannot find " +
convert_pir_op_name(OpType()) +
" in activation op_mapper.");
if (convert_pir_op_name(OpType()) == "erf") {
auto input = helper_->AutoCast(input_info[0].name, input_info[0].dtype,
P2ODataType::FP32);
Expand Down Expand Up @@ -367,7 +380,9 @@ void ThresholdedReluMapper::Opset10() {
void Log1PMapper::Opset7() {
auto x_info = GetInput("X");
auto out_info = GetOutput("Out");
auto one = helper_->Constant({}, GetOnnxDtype(x_info[0].dtype), float(1.0));
auto one = helper_->Constant({},
GetOnnxDtype(x_info[0].dtype),
static_cast<float>(1.0));
auto input = helper_->MakeNode("Add", {x_info[0].name, one})->output(0);
helper_->MakeNode("Log", {input}, {out_info[0].name});
}
Expand Down
29 changes: 29 additions & 0 deletions paddle2onnx/mapper/activation/activation.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ class LeakyReluMapper : public Mapper {
GetAttr("alpha", &alpha_);
}

LeakyReluMapper(const PaddlePirParser& p, OnnxHelper* helper,
int64_t op_id)
: Mapper(p, helper, op_id) {
in_pir_mode = true;
GetAttr("alpha", &alpha_);
}

void Opset7() override;

private:
Expand All @@ -124,6 +131,12 @@ class GeluMapper : public Mapper {
int64_t op_id)
: Mapper(p, helper, block_id, op_id) {}

GeluMapper(const PaddlePirParser& p, OnnxHelper* helper,
int64_t op_id)
: Mapper(p, helper, op_id) {
in_pir_mode = true;
}

int32_t GetMinOpsetVersion(bool verbose) override {
Logger(verbose, 9) << RequireOpset(9) << std::endl;
return 9;
Expand All @@ -144,6 +157,17 @@ class SoftMaxMapper : public Mapper {
}
}

SoftMaxMapper(const PaddlePirParser& p, OnnxHelper* helper,
int64_t op_id)
: Mapper(p, helper, op_id) {
in_pir_mode = true;
if (HasAttr("axis")) {
GetAttr("axis", &axis_);
} else {
axis_ = -1;
}
}

void Opset7() override;
void Opset13() override;

Expand Down Expand Up @@ -310,6 +334,11 @@ class SiluMapper : public Mapper {
SiluMapper(const PaddleParser& p, OnnxHelper* helper, int64_t block_id,
int64_t op_id)
: Mapper(p, helper, block_id, op_id) {}
SiluMapper(const PaddlePirParser& p, OnnxHelper* helper,
int64_t op_id)
: Mapper(p, helper, op_id) {
in_pir_mode = true;
}
void Opset7() override;
};

Expand Down
3 changes: 2 additions & 1 deletion paddle2onnx/mapper/activation/sigmoid.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@

namespace paddle2onnx {
REGISTER_MAPPER(sigmoid, SigmoidMapper)
REGISTER_PIR_MAPPER(sigmoid, SigmoidMapper)

void SigmoidMapper::Opset7() {
auto input_info = GetInput("X");
auto output_info = GetOutput("Out");
helper_->MakeNode("Sigmoid", {input_info[0].name}, {output_info[0].name});
}
}
}
12 changes: 8 additions & 4 deletions paddle2onnx/mapper/activation/sigmoid.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,24 @@
// limitations under the License.
#pragma once


#include "paddle2onnx/mapper/mapper.h"

#include <cmath>
#include <map>
#include <string>
#include <vector>

#include "paddle2onnx/mapper/mapper.h"

namespace paddle2onnx {
class SigmoidMapper : public Mapper {
public:
SigmoidMapper(const PaddleParser& p, OnnxHelper* helper, int64_t block_id,
int64_t op_id)
: Mapper(p, helper, block_id, op_id) {}
SigmoidMapper(const PaddlePirParser& p, OnnxHelper* helper,
int64_t op_id)
: Mapper(p, helper, op_id) {
in_pir_mode = true;
}
void Opset7() override;
};
}
} // namespace paddle2onnx
14 changes: 10 additions & 4 deletions paddle2onnx/mapper/activation/swish.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

namespace paddle2onnx {
REGISTER_MAPPER(swish, SwishMapper)
REGISTER_PIR_MAPPER(swish, SwishMapper)

void SwishMapper::Opset7() {
auto input_info = GetInput("X");
Expand All @@ -25,13 +26,18 @@ void SwishMapper::Opset7() {
if (HasAttr("beta")) {
float temp_beta = 1.0;
GetAttr("beta", &temp_beta);
std::string beta_node = helper_->Constant({}, GetOnnxDtype(input_info[0].dtype), temp_beta);
auto beta_x_node = helper_->MakeNode("Mul", {input_info[0].name, beta_node});
std::string beta_node = helper_->Constant({},
GetOnnxDtype(input_info[0].dtype),
temp_beta);
auto beta_x_node = helper_->MakeNode("Mul",
{input_info[0].name, beta_node});
sigmod_node = helper_->MakeNode("Sigmoid", {beta_x_node->output(0)});
} else {
sigmod_node = helper_->MakeNode("Sigmoid", {input_info[0].name});
}

helper_->MakeNode("Mul", {input_info[0].name, sigmod_node->output(0)}, {output_info[0].name});
helper_->MakeNode("Mul",
{input_info[0].name, sigmod_node->output(0)},
{output_info[0].name});
}
}
} // namespace paddle2onnx
11 changes: 8 additions & 3 deletions paddle2onnx/mapper/activation/swish.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,24 @@
#pragma once


#include "paddle2onnx/mapper/mapper.h"

#include <cmath>
#include <map>
#include <string>
#include <vector>

#include "paddle2onnx/mapper/mapper.h"

namespace paddle2onnx {
class SwishMapper : public Mapper {
public:
SwishMapper(const PaddleParser& p, OnnxHelper* helper, int64_t block_id,
int64_t op_id)
: Mapper(p, helper, block_id, op_id) {}
SwishMapper(const PaddlePirParser& p, OnnxHelper* helper,
int64_t op_id)
: Mapper(p, helper, op_id) {
in_pir_mode = true;
}
void Opset7() override;
};
}
} // namespace paddle2onnx
32 changes: 18 additions & 14 deletions paddle2onnx/mapper/exporter.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,24 @@
#endif

inline std::string convert_pir_op_name(const std::string pir_op_name) {
std::unordered_map<std::string, std::string> op_name_mappings = {
{"matmul", "matmul_v2"},
// {"relu", "relu6"},
{"batch_norm_", "batch_norm"},
{"flatten", "flatten_contiguous_range"},
{"add", "elementwise_add"}};
std::unordered_map<std::string, std::string> op_name_mappings = {
{"matmul", "matmul_v2"},
// {"relu", "relu6"},
{"batch_norm_", "batch_norm"},
{"assign_value_", "assign_value"},
{"flatten", "flatten_contiguous_range"},
{"add", "elementwise_add"}};
std::string op_name = pir_op_name;
std::string prefix = "pd_op.";
std::string builtin_prefix = "builtin.";

size_t prefix_pos = op_name.find(prefix);
if (prefix_pos != std::string::npos) {
op_name = op_name.substr(prefix_pos + prefix.size());
}
else {
if(op_name.substr(0, builtin_prefix.size()) == builtin_prefix) {
} else {
if(op_name.substr(0, builtin_prefix.size()) == builtin_prefix) {
op_name[builtin_prefix.size() - 1] = '_';
}
}
}
auto it = op_name_mappings.find(op_name);
if (it != op_name_mappings.end()) {
Expand Down Expand Up @@ -162,8 +162,10 @@ class ModelExporter {
&inputs,
std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>
&outputs,
std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>> &nodes,
std::map<std::string, QuantizeInfo> &quantize_info);
std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>
&nodes,
std::map<std::string, QuantizeInfo>
&quantize_info);
// Update constant node in parameters. When process quantize model, the weight
// dtype may be int8, it should be convet to float32 and use this function to
// update converted params.
Expand All @@ -181,15 +183,17 @@ class ModelExporter {
ONNX_NAMESPACE::GraphProto ExportBlock(
const PaddleParser &parser,
int32_t block_id,
std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>> &parameters,
std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>
&parameters,
std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>
&inputs,
std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>
&outputs);

ONNX_NAMESPACE::GraphProto ExportBlock(
const PaddlePirParser &pir_parser,
std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>> &parameters,
std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>
&parameters,
std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>
&inputs,
std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>
Expand Down
2 changes: 2 additions & 0 deletions paddle2onnx/mapper/nn/conv2d_transpose.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@

namespace paddle2onnx {
REGISTER_MAPPER(conv2d_transpose, Conv2dTransposeMapper)
REGISTER_PIR_MAPPER(conv2d_transpose, Conv2dTransposeMapper)
REGISTER_MAPPER(depthwise_conv2d_transpose, Conv2dTransposeMapper)
REGISTER_PIR_MAPPER(depthwise_conv2d_transpose, Conv2dTransposeMapper)

int32_t Conv2dTransposeMapper::GetMinOpsetVersion(bool verbose) {
// NHWC is not supported
Expand Down
29 changes: 28 additions & 1 deletion paddle2onnx/mapper/nn/conv2d_transpose.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class Conv2dTransposeMapper : public Mapper {
GetAttr("padding_algorithm", &padding_algorithm_);
GetAttr("data_format", &data_format_);

if (HasAttr("output_padding")){
if (HasAttr("output_padding")) {
GetAttr("output_padding", &output_padding_);
}
GetAttr("output_size", &output_size_);
Expand All @@ -46,6 +46,33 @@ class Conv2dTransposeMapper : public Mapper {
}
}

Conv2dTransposeMapper(const PaddlePirParser& p, OnnxHelper* helper,
int64_t op_id)
: Mapper(p, helper, op_id) {
in_pir_mode = true;
GetAttr("groups", &groups_);
GetAttr("dilations", &dilations_);
GetAttr("strides", &strides_);
GetAttr("paddings", &paddings_);
GetAttr("padding_algorithm", &padding_algorithm_);
GetAttr("data_format", &data_format_);

if (HasAttr("output_padding")) {
GetAttr("output_padding", &output_padding_);
}
if (HasAttr("output_size")) {
GetAttr("output_size", &output_size_);
}
if (paddings_.size() == 2) {
paddings_.push_back(paddings_[0]);
paddings_.push_back(paddings_[1]);
} else if (paddings_.size() == 4) {
int32_t tmp = paddings_[1];
paddings_[1] = paddings_[2];
paddings_[2] = tmp;
}
}

int32_t GetMinOpsetVersion(bool verbose) override;
void Opset7() override;

Expand Down
1 change: 1 addition & 0 deletions paddle2onnx/mapper/nn/group_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

namespace paddle2onnx {
REGISTER_MAPPER(group_norm, GroupNormMapper)
REGISTER_PIR_MAPPER(group_norm, GroupNormMapper)

int32_t GroupNormMapper::GetMinOpsetVersion(bool verbose) {
auto input_info = GetInput("X");
Expand Down
7 changes: 7 additions & 0 deletions paddle2onnx/mapper/nn/group_norm.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ class GroupNormMapper : public Mapper {
GetAttr("groups", &groups_);
GetAttr("epsilon", &epsilon_);
}
GroupNormMapper(const PaddlePirParser& p, OnnxHelper* helper,
int64_t op_id)
: Mapper(p, helper, op_id) {
in_pir_mode = true;
GetAttr("groups", &groups_);
GetAttr("epsilon", &epsilon_);
}

int32_t GetMinOpsetVersion(bool verbose) override;
void Opset7() override;
Expand Down
1 change: 1 addition & 0 deletions paddle2onnx/mapper/nn/shape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

namespace paddle2onnx {
REGISTER_MAPPER(shape, ShapeMapper)
REGISTER_PIR_MAPPER(shape, ShapeMapper)

void ShapeMapper::Opset7() {
auto input_info = GetInput("Input");
Expand Down
5 changes: 5 additions & 0 deletions paddle2onnx/mapper/nn/shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ class ShapeMapper : public Mapper {
ShapeMapper(const PaddleParser& p, OnnxHelper* helper, int64_t block_id,
int64_t op_id)
: Mapper(p, helper, block_id, op_id) {}
ShapeMapper(const PaddlePirParser& p, OnnxHelper* helper,
int64_t op_id)
: Mapper(p, helper, op_id) {
in_pir_mode = true;
}

void Opset7() override;
};
Expand Down
Loading