forked from microsoft/onnxruntime
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgraph_transformer.h
More file actions
141 lines (112 loc) · 5.14 KB
/
graph_transformer.h
File metadata and controls
141 lines (112 loc) · 5.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/graph/graph_viewer.h"
#include "core/optimizer/rewrite_rule.h"
namespace onnxruntime {
/**
@class GraphTransformer
The interface for in-place transformation of a Graph.
Derived classes implement ApplyImpl; callers use Apply. Instances are neither copyable nor movable.
*/
class GraphTransformer {
 public:
  GraphTransformer(const std::string& name, const std::string& desc)
      : name_(name), desc_(desc) {
  }

  virtual ~GraphTransformer() = default;

  // The copy/move operations are deleted. They are declared public so that misuse is reported as
  // "use of deleted function" rather than as an access-control error.
  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphTransformer);

  /** Gets the name of this graph transformer. */
  const std::string& Name() const noexcept {
    return name_;
  }

  /** Gets the description of this graph transformer. */
  const std::string& Description() const noexcept {
    return desc_;
  }

  /** Apply the in-place transformation defined by this transformer to the provided Graph instance.
  @param graph The Graph to transform in place.
  @param[out] modified Set to true if the Graph was modified.
  @returns Status with success or error information.
  */
  common::Status Apply(Graph& graph, bool& modified) const;

 protected:
  /** Helper method to call ApplyImpl on every subgraph attached to an attribute of the Node
  (e.g. the bodies of control flow nodes such as Scan/If/Loop).
  @param node The Node whose subgraphs should be transformed.
  @param[out] modified Forwarded to ApplyImpl for each subgraph.
  @param graph_level The level of the graph containing node; its subgraphs are processed at graph_level + 1.
  @returns The error status from ApplyImpl (via ORT_RETURN_IF_ERROR) if a subgraph fails, otherwise Status::OK().
  */
  common::Status Recurse(Node& node, bool& modified, int graph_level) const {
    // Compute the child level instead of mutating the by-value parameter.
    const int subgraph_level = graph_level + 1;
    for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
      auto& subgraph = *entry.second;
      ORT_RETURN_IF_ERROR(ApplyImpl(subgraph, modified, subgraph_level));
    }
    return Status::OK();
  }

 private:
  // Apply the transform to the graph.
  // graph_level is 0 for the main graph, and is incremented when descending into the subgraph of a node.
  // You MUST call Recurse for all valid Nodes in the graph to ensure any subgraphs in control flow nodes
  // (Scan/If/Loop) are processed as well.
  // You should avoid calling Graph::Resolve in ApplyImpl unless you are 100% sure it's required. In most cases
  // the call to Graph::Resolve in Apply prior to ApplyImpl being called, and after ApplyImpl for the main graph
  // completes (if 'modified' is true) should suffice.
  // NOTE: default arguments on virtual functions are bound by the static type of the call. The '= 0' default
  // applies only to calls made through a GraphTransformer-typed expression; overrides that omit the default
  // require graph_level to be passed explicitly.
  virtual common::Status ApplyImpl(Graph& graph, bool& modified, int graph_level = 0) const = 0;

  const std::string name_;
  const std::string desc_;
};
/**
@class RuleBasedGraphTransformer
A GraphTransformer whose behavior is defined by a set of rewrite rules.
It offers an API to register rewrite rules keyed by op_type and an API to look them up. The rewriting
strategy, i.e. how the graph is traversed and how the registered rules are applied, is left to derived
classes through ApplyImpl. Several strategies are possible, each with different trade-offs; at the moment
TopDownRuleBasedTransformer, which performs a top-down traversal of the nodes, is the only one defined.
@TODO: Is a bottom-up traversal more efficient?
@TODO: Is it worth adding the max number of passes a rule should be applied for?
@TODO: We need to define a contract about whether a rewrite rule is allowed to leave
the graph in an inconsistent state (this will determine when and where we will be
calling Graph::Resolve()).
*/
class RuleBasedGraphTransformer : public GraphTransformer {
 public:
  RuleBasedGraphTransformer(const std::string& name, const std::string& desc)
      : GraphTransformer(name, desc) {}

  /**
  Register a rewriting rule for the given op_type; ownership of the rule is transferred to this transformer.
  @TODO (revisit needed): Keying by OpSignature* would require OpSignature instances to be stored globally;
  otherwise the same operator or function could be reached through several different addresses.
  A possible alternative key is an OpSignature ID of the form name_domain_version.
  For now the op type string is used instead of the OpSchema; a version should probably be added as well.
  */
  Status Register(const std::string& op_type, std::unique_ptr<RewriteRule> rule);

  /** Check whether rules are registered for the given op_type.
  @returns true if op_type has an entry in the rule map, false otherwise.
  */
  bool HasRules(const std::string& op_type) const {
    return GetRewriteRules(op_type) != nullptr;
  }

  /**
  Look up the rewrite rules registered for the given op_type.
  @returns A non-owning pointer to the vector of rules registered for op_type, or nullptr if there is none.
  */
  const std::vector<std::unique_ptr<RewriteRule>>* GetRewriteRules(const std::string& op_type) const {
    const auto found = op_to_rules_.find(op_type);
    return found == op_to_rules_.cend() ? nullptr : &found->second;
  }

 private:
  // Maps an op_type to the rules registered for it.
  using RewriteRuleSet = std::unordered_map<std::string, std::vector<std::unique_ptr<RewriteRule>>>;
  RewriteRuleSet op_to_rules_;
};
/**
@class TopDownRuleBasedTransformer
This is a rule-based Graph transformer that applies rules by performing top-down passes of the Graph.
*/
class TopDownRuleBasedTransformer : public RuleBasedGraphTransformer {
public:
TopDownRuleBasedTransformer(const std::string& name, const std::string& desc)
: RuleBasedGraphTransformer(name, desc) {}
private:
// Performs a single top-down traversal of the graph and applies all registered rules.
common::Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
};
} // namespace onnxruntime