Dragon - C++ API
A Computation Graph Virtual Machine Based Deep Learning Framework
graph.h
Go to the documentation of this file.
1 
13 #ifndef DRAGON_CORE_GRAPH_H_
14 #define DRAGON_CORE_GRAPH_H_
15 
16 #include "core/common.h"
17 #include "core/operator.h"
18 
19 namespace dragon {
20 
21 class GraphBase {
22  public:
24  GraphBase(
25  const GraphDef& def,
26  Workspace* ws);
27 
29  virtual ~GraphBase() {}
30 
32  virtual bool Create(
33  const GraphDef& def,
34  Workspace* ws) = 0;
35 
37  virtual bool Run(
38  const string& include,
39  const string& exclude,
40  int stream_id = 0) = 0;
41 
43  string name() const { return name_; }
44 
46  const string& phase() const { return phase_; }
47 
50 
52  const Argument& arg(const string& name) { return *(args_[name]); }
53 
55  const GraphDef& def() const { return def_; }
56 
58  const GraphDef& opt_def() const { return opt_def_; }
59 
61  Workspace* ws() const { return ws_; }
62 
63  protected:
65  string name_, phase_;
66 
69 
72 
74  GraphDef def_, opt_def_;
75 };
76 
77 class Graph : public GraphBase {
78  public:
80  Graph(const GraphDef& def, Workspace* ws);
81 
83  virtual ~Graph() { for (auto* op : ops_) delete op; }
84 
86  bool Create(
87  const GraphDef& def,
88  Workspace* ws) override;
89 
91  bool Run(
92  const string& include,
93  const string& exclude,
94  int stream_id = 0) override;
95 
96  protected:
98  vector<OperatorBase*> ops_;
99 };
100 
103  const GraphDef& def,
104  Workspace* ws);
105 
106 /* Macros */
107 
109  GraphRegistry,
110  GraphBase,
111  const GraphDef&,
112  Workspace*);
113 
114 #define REGISTER_GRAPH(name, ...) \
115  REGISTER_CLASS(GraphRegistry, name, __VA_ARGS__)
116 
117 } // namespace dragon
118 
119 #endif // DRAGON_CORE_GRAPH_H_
string name_
Store the name and running phase.
Definition: graph.h:65
Workspace * ws_
Store the parent workspace.
Definition: graph.h:71
Workspace * ws() const
Return the parent workspace.
Definition: graph.h:61
Definition: workspace.h:20
std::unordered_map< Key, Value > Map
Definition: common.h:54
string phase_
Definition: graph.h:65
bool Create(const GraphDef &def, Workspace *ws) override
Create a graph from the optimized def.
Definition: graph.cc:57
GraphDef def_
Store the def.
Definition: graph.h:74
GraphBase(const GraphDef &def, Workspace *ws)
Constructor taking the graph def and the parent workspace.
Definition: graph.cc:10
vector< OperatorBase * > ops_
Store the internal operators.
Definition: graph.h:98
bool Run(const string &include, const string &exclude, int stream_id=0) override
Run the graph once synchronously.
Definition: graph.cc:122
DECLARE_REGISTRY(GraphRegistry, GraphBase, const GraphDef &, Workspace *)
const Map< std::string, const Argument * > & args()
Return the argument map.
Definition: graph.h:49
virtual bool Create(const GraphDef &def, Workspace *ws)=0
Create a graph from the optimized def.
string name() const
Return the graph name.
Definition: graph.h:43
Definition: graph.h:21
const Argument & arg(const string &name)
Return the specified argument.
Definition: graph.h:52
Definition: graph.h:77
GraphDef opt_def_
Definition: graph.h:74
virtual bool Run(const string &include, const string &exclude, int stream_id=0)=0
Run the graph once synchronously.
GraphBase * NewGraph(const GraphDef &def, Workspace *ws)
Create a graph from the raw def.
Definition: graph.cc:144
const GraphDef & opt_def() const
Return the stored opt def.
Definition: graph.h:58
const string & phase() const
Return the defined running phase.
Definition: graph.h:46
virtual ~Graph()
Destructor; deletes the internal operators.
Definition: graph.h:83
Graph(const GraphDef &def, Workspace *ws)
Constructor taking the graph def and the parent workspace.
Definition: graph.cc:83
const GraphDef & def() const
Return the stored raw def.
Definition: graph.h:55
Map< string, const Argument * > args_
Store the defined arguments.
Definition: graph.h:68
Definition: common.h:41
virtual ~GraphBase()
Virtual destructor.
Definition: graph.h:29