Dragon - C++ API
A Computation Graph Virtual Machine Based Deep Learning Framework
onnx_backend.h
Go to the documentation of this file.
1 
17 #ifndef DRAGON_ONNX_ONNX_BACKEND_H_
18 #define DRAGON_ONNX_ONNX_BACKEND_H_
19 
20 #include "core/common.h"
21 #include "proto/onnx.pb.h"
22 
23 #define ONNX_NAMESPACE onnx_dragon
24 
25 namespace dragon {
26 
27 namespace onnx {
28 
29 const int kKnownOpsetVersion = 9;
30 
31 using ONNX_NAMESPACE::AttributeProto;
32 using ONNX_NAMESPACE::GraphProto;
33 using ONNX_NAMESPACE::ModelProto;
34 using ONNX_NAMESPACE::NodeProto;
35 using ONNX_NAMESPACE::TensorProto;
36 using ONNX_NAMESPACE::ValueInfoProto;
37 
40 
42  public:
46  const int opset_version)
47  : value_infos_(value_infos),
48  initializer_(initializer),
49  opset_version_(opset_version) {}
50 
51  const ValueInfoMap& value_infos() const { return value_infos_; }
52  const InitializerMap& initializer() const { return initializer_; }
53 
54  const int opset_version() const { return opset_version_; }
55 
56  private:
57  const ValueInfoMap& value_infos_;
58  const InitializerMap& initializer_;
59  const int opset_version_;
60 };
61 
62 typedef struct {
63  vector<OperatorDef> ops;
64 
65  OperatorDef* AddOp() {
66  ops.emplace_back(OperatorDef());
67  return &ops.back();
68  }
69 
70  OperatorDef* GetOp(int index) {
71  CHECK_LT(index, ops.size());
72  return &ops[index];
73  }
74 
75 } ONNXImporterReturns;
76 
78  public:
79  ONNXAttributes(const NodeProto& node);
80 
81  bool HasAttribute(const string& key) const {
82  return onnx_attrs_.count(key) > 0;
83  }
84 
85  AttributeProto* AddRewrittenAttribute(const string& key) {
86  auto tmp = rewritten_onnx_attrs_.emplace(key, AttributeProto());
87  auto& attr = tmp.first->second;
88  attr.set_name(key);
89  return &attr;
90  }
91 
92  google::protobuf::RepeatedPtrField<Argument> AttrToArg(
93  std::function<string(const string&)> mapper) const;
94 
95  template <typename T>
96  T get(const string& key) const;
97 
98  template <typename T>
99  T get(const string& key, const T& default_value) const {
100  if (onnx_attrs_.count(key)) {
101  return get<T>(key);
102  } else {
103  return default_value;
104  }
105  }
106 
107  const AttributeProto* remove(const string& key) {
108  const AttributeProto* result = nullptr;
109  auto iter = onnx_attrs_.find(key);
110  if (iter != onnx_attrs_.end()) {
111  result = iter->second;
112  onnx_attrs_.erase(iter);
113  }
114  return result;
115  }
116 
117  private:
119  Map<string, AttributeProto> rewritten_onnx_attrs_;
120 };
121 
122 template <> int64_t ONNXAttributes::get(const string& key) const;
123 template <> float ONNXAttributes::get(const string& key) const;
124 
125 template <> google::protobuf::RepeatedPtrField<string>
126  ONNXAttributes::get(const string& key) const;
127 
128 template <> google::protobuf::RepeatedField<google::protobuf::int64>
129  ONNXAttributes::get(const string& key) const;
130 
131 template <> google::protobuf::RepeatedField<float>
132  ONNXAttributes::get(const string& key) const;
133 
134 template <> const TensorProto* ONNXAttributes::get(const std::string& key) const;
135 
136 struct ONNXNode {
137  ONNXNode(const NodeProto& node_in)
138  : node(node_in), attributes(node_in) {}
139 
140  const NodeProto& node;
142 };
143 
144 class ONNXBackend {
145  public:
146  void BuildTensorFillOp(
147  const TensorProto& onnx_tensor,
148  OperatorDef* op_def);
149 
151  ONNXNode* onnx_node,
152  const ConversionContext& ctx);
153 
155  ONNXNode* onnx_node,
156  const ConversionContext& ctx);
157 
159  ONNXNode* onnx_node,
160  const ConversionContext& ctx);
161 
163  ONNXNode* onnx_node,
164  const ConversionContext& ctx);
165 
167  ONNXNode* onnx_node,
168  const ConversionContext& ctx);
169 
171  ONNXNode* onnx_node,
172  const ConversionContext& ctx);
173 
175  ONNXNode* onnx_node,
176  const ConversionContext& ctx);
177 
179  ONNXNode* onnx_node,
180  const ConversionContext& ctx);
181 
183  ONNXNode* onnx_node,
184  const ConversionContext& ctx);
185 
187  ONNXNode* onnx_node,
188  const ConversionContext& ctx);
189 
191  ONNXNode* onnx_node,
192  const ConversionContext& ctx);
193 
194  void Prepare(
195  const string& onnx_model_path,
196  GraphDef* init_graph,
197  GraphDef* pred_graph);
198 
200  const TensorProto& onnx_tensor,
201  Argument* dtype,
202  Argument* values);
203 
205  const ModelProto& init_model,
206  const ModelProto& pred_model,
207  const ConversionContext& ctx,
208  ONNXNode* onnx_node);
209 
210  void ONNXToDragon(
211  const ModelProto& onnx_model,
212  const int opset_version,
213  const bool include_initializers,
214  GraphDef* init_graph,
215  GraphDef* pred_graph);
216 
219 
220  const Map<string, string>& get_renamed_nodes() const;
222 
223  const Map<string, string>& get_renamed_attrs() const;
225 };
226 
227 } // namespace onnx
228 
229 } // namespace dragon
230 
231 #endif // DRAGON_ONNX_ONNX_BACKEND_H_
OperatorDef * AddOp()
Definition: onnx_backend.h:65
ONNXImporterReturns LpNormNodeImporter(ONNXNode *onnx_node, const ConversionContext &ctx)
Definition: onnx_importer.cc:323
void ONNXToDragon(const ModelProto &onnx_model, const int opset_version, const bool include_initializers, GraphDef *init_graph, GraphDef *pred_graph)
Definition: onnx_backend.cc:51
ONNXImporterReturns ONNXNodeToOps(const ModelProto &init_model, const ModelProto &pred_model, const ConversionContext &ctx, ONNXNode *onnx_node)
Definition: onnx_backend.cc:125
const int kKnownOpsetVersion
Definition: onnx_backend.h:29
std::unordered_map< Key, Value > Map
Definition: common.h:54
ONNXImporterReturns ReshapeNodeImporter(ONNXNode *onnx_node, const ConversionContext &ctx)
Definition: onnx_importer.cc:195
const ValueInfoMap & value_infos() const
Definition: onnx_backend.h:51
Definition: onnx_backend.h:77
ONNXAttributes attributes
Definition: onnx_backend.h:141
ONNXImporterReturns ATenNodeImporter(ONNXNode *onnx_node, const ConversionContext &ctx)
Definition: onnx_importer.cc:230
Definition: onnx_backend.h:41
Map< string, const TensorProto * > InitializerMap
Definition: onnx_backend.h:39
const Map< string, string > & get_renamed_nodes() const
Definition: onnx_backend.cc:147
ONNXImporterReturns ConvPoolNodeImporter(ONNXNode *onnx_node, const ConversionContext &ctx)
Definition: onnx_importer.cc:52
ONNXImporterReturns GemmNodeImporter(ONNXNode *onnx_node, const ConversionContext &ctx)
Definition: onnx_importer.cc:87
const Map< string, SpecialNodeConverter > & get_special_nodes() const
Definition: onnx_backend.cc:163
OperatorDef * GetOp(int index)
Definition: onnx_backend.h:70
#define CHECK_LT(val1, val2)
Definition: logging.h:52
ONNXImporterReturns MaxRoiPoolNodeImporter(ONNXNode *onnx_node, const ConversionContext &ctx)
Definition: onnx_importer.cc:252
ONNXImporterReturns ArgReduceNodeImporter(ONNXNode *onnx_node, const ConversionContext &ctx)
Definition: onnx_importer.cc:343
Definition: onnx_backend.h:62
const Map< string, Map< string, string > > & get_node_renamed_attrs() const
Definition: onnx_backend.cc:184
Definition: onnx_backend.h:144
ONNXImporterReturns CastNodeImporter(ONNXNode *onnx_node, const ConversionContext &ctx)
Definition: onnx_importer.cc:269
const int opset_version() const
Definition: onnx_backend.h:54
const Map< string, string > & get_renamed_attrs() const
Definition: onnx_backend.cc:192
ONNXImporterReturns CommonONNXNodeImporter(ONNXNode *onnx_node, const ConversionContext &ctx)
Definition: onnx_importer.cc:18
vector< OperatorDef > ops
Definition: onnx_backend.h:63
void BuildTensorFillOp(const TensorProto &onnx_tensor, OperatorDef *op_def)
Definition: onnx_initializer.cc:150
ONNXAttributes(const NodeProto &node)
Definition: onnx_attibute.cc:7
ONNXNode(const NodeProto &node_in)
Definition: onnx_backend.h:137
T get(const string &key) const
const InitializerMap & initializer() const
Definition: onnx_backend.h:52
ONNXImporterReturns UpsampleNodeImporter(ONNXNode *onnx_node, const ConversionContext &ctx)
Definition: onnx_importer.cc:141
ONNXImporterReturns BatchNormNodeImporter(ONNXNode *onnx_node, const ConversionContext &ctx)
Definition: onnx_importer.cc:110
google::protobuf::RepeatedPtrField< Argument > AttrToArg(std::function< string(const string &)> mapper) const
Definition: onnx_attibute.cc:113
const NodeProto & node
Definition: onnx_backend.h:140
T get(const string &key, const T &default_value) const
Definition: onnx_backend.h:99
ConversionContext(const ValueInfoMap &value_infos, const InitializerMap &initializer, const int opset_version)
Definition: onnx_backend.h:43
Map< string, ValueInfoProto > ValueInfoMap
Definition: onnx_backend.h:38
const AttributeProto * remove(const string &key)
Definition: onnx_backend.h:107
void ONNXTensorToArgument(const TensorProto &onnx_tensor, Argument *dtype, Argument *values)
Definition: onnx_initializer.cc:77
Definition: common.h:41
bool HasAttribute(const string &key) const
Definition: onnx_backend.h:81
Definition: onnx_backend.h:136
AttributeProto * AddRewrittenAttribute(const string &key)
Definition: onnx_backend.h:85
void Prepare(const string &onnx_model_path, GraphDef *init_graph, GraphDef *pred_graph)
Definition: onnx_backend.cc:9
ONNXImporterReturns(ONNXBackend::*)(ONNXNode *, const ConversionContext &) SpecialNodeConverter
Definition: onnx_backend.h:218