Dragon - C++ API
A Computation Graph Virtual Machine Based Deep Learning Framework
workspace.h
Go to the documentation of this file.
1 
13 #ifndef DRAGON_CORE_WORKSPACE_H_
14 #define DRAGON_CORE_WORKSPACE_H_
15 
16 #include "core/graph.h"
17 
18 namespace dragon {
19 
20 class Workspace {
21  public:
28 
30  Workspace(const string& name) : name_(name) { Initialize(); }
31 
33  const string& name() { return name_; }
34 
36  vector<string> tensors() const;
37 
39  vector<string> graphs() const;
40 
42  void Initialize();
43 
45  void Clear();
46 
48  void MergeFrom(Workspace*);
49 
51  string GetTensorName(const string&) const;
52 
54  Tensor* TryGetTensor(const string&, bool = true) const;
55 
57  bool HasTensor(const string& name, bool use_remote = true) const {
58  return TryGetTensor(name, use_remote) ? true : false;
59  }
60 
62  Tensor* CreateTensor(const string&);
63 
65  Tensor* GetTensor(const string&, bool = true) const;
66 
68  void ResetTensor(const string&);
69 
70  /* \brief Whether the specified filler is in this workspace */
71  bool HasFiller(const string&, bool = true) const;
72 
74  void CreateFiller(const TensorFillerProto&);
75 
77  const TensorFillerProto* GetFiller(const string&) const;
78 
80  template <class Context>
81  vector<void*> data(const vector<size_t>& segments) {
82  int64_t nbytes = 0;
83  vector<void*> ret(segments.size());
84  for (auto& segment : segments) nbytes += (int64_t)segment;
85  auto* T = CreateTensor("/share/data")->Reshape({ nbytes });
86  ret[0] = T->template mutable_data<uint8_t, Context>();
87  for (int i = 1; i < segments.size(); i++)
88  ret[i] = (uint8_t*)ret[i - 1] + segments[i - 1];
89  return ret;
90  }
91 
93  template <typename T, class Context>
94  vector<T*> data(const vector<int64_t>& segments) {
95  vector<size_t> segments_in_byte;
96  vector<T*> ret(segments.size());
97  for (const auto& e : segments)
98  segments_in_byte.emplace_back(e * sizeof(T));
99  auto ret_in_byte = data<Context>(segments_in_byte);
100  for (int i = 0; i < segments.size(); i++)
101  ret[i] = (T*)ret_in_byte[i];
102  return ret;
103  }
104 
106  OperatorBase* CreateOperator(const OperatorDef&);
107 
109  void RunOperator(const OperatorDef&);
110 
112  void RunOperatorOnce(const OperatorDef&);
113 
115  GraphBase* CreateGraph(const GraphDef&);
116 
118  void RunGraph(
119  const string& graph_name,
120  const string& include,
121  const string& exclude,
122  int stream_id = 0);
123 
124  /* \brief Set an alias for the tensor */
125  bool SetTensorAlias(const string& name, const string& alias);
126 
127  /* \brief Return a unique dummy name within this workspace */
128  string GetDummyName(
129  const string& base_name,
130  const string& suffix,
131  const string& domain = "",
132  const bool zero_based = true);
133 
134  private:
136  string name_;
137 
139  DummyNameMap dummy_name_map_;
140 
142  TensorMap tensor_map_;
143 
145  TensorFillerMap tensor_filler_map_;
146 
148  TensorAliasMap tensor_alias_map_;
149 
151  OperatorMap operator_map_;
152 
154  GraphMap graph_map_;
155 
157  vector<Workspace*> remote_workspaces_;
158 };
159 
160 } // namespace dragon
161 
162 #endif // DRAGON_CORE_WORKSPACE_H_
Map< string, TensorFillerProto > TensorFillerMap
Definition: workspace.h:25
Definition: workspace.h:20
std::unordered_map< Key, Value > Map
Definition: common.h:54
void RunOperator(const OperatorDef &)
Run the specified persistent operator.
Definition: workspace.cc:164
bool SetTensorAlias(const string &name, const string &alias)
Definition: workspace.cc:221
GraphBase * CreateGraph(const GraphDef &)
Create a Graph in this workspace.
Definition: workspace.cc:179
vector< string > graphs() const
Return the name of stored graphs.
Definition: workspace.cc:212
void MergeFrom(Workspace *)
Merge from an external workspace.
Definition: workspace.cc:26
bool HasTensor(const string &name, bool use_remote=true) const
Whether the specified tensor is in this workspace.
Definition: workspace.h:57
void CreateFiller(const TensorFillerProto &)
Create the specified filler.
Definition: workspace.cc:126
Tensor * GetTensor(const string &, bool=true) const
Return the specified tensor.
Definition: workspace.cc:75
void Clear()
Destroy all the tensors.
Definition: workspace.cc:19
bool HasFiller(const string &, bool=true) const
Definition: workspace.cc:112
Definition: tensor.h:21
Definition: graph.h:21
void RunOperatorOnce(const OperatorDef &)
Try to run the operator in an adaptive mode.
Definition: workspace.cc:171
Tensor * TryGetTensor(const string &, bool=true) const
Try to search for the specified tensor in this workspace.
Definition: workspace.cc:41
Map< string, Map< string, int64_t > > DummyNameMap
Definition: workspace.h:22
OperatorBase * CreateOperator(const OperatorDef &)
Create an operator in this workspace.
Definition: workspace.cc:152
string GetTensorName(const string &) const
Query the real name of the specified tensor.
Definition: workspace.cc:33
Map< string, string > TensorAliasMap
Definition: workspace.h:24
vector< T * > data(const vector< int64_t > &segments)
Create temporary cache segments with the specified type.
Definition: workspace.h:94
vector< string > tensors() const
Return the name of stored tensors.
Definition: workspace.cc:95
Map< string, unique_ptr< GraphBase > > GraphMap
Definition: workspace.h:27
void RunGraph(const string &graph_name, const string &include, const string &exclude, int stream_id=0)
Run the specified graph by name and rules.
Definition: workspace.cc:199
void Initialize()
Create some internal tensors.
Definition: workspace.cc:9
Tensor * Reshape(const vec64_t &dims)
Reshape to the given dimensions.
Definition: tensor.h:38
vector< void * > data(const vector< size_t > &segments)
Create temporary data segments.
Definition: workspace.h:81
const TensorFillerProto * GetFiller(const string &) const
Return the specified filler.
Definition: workspace.cc:136
Workspace(const string &name)
Constructor.
Definition: workspace.h:30
Map< string, unique_ptr< OperatorBase > > OperatorMap
Definition: workspace.h:26
Map< string, unique_ptr< Tensor > > TensorMap
Definition: workspace.h:23
Tensor * CreateTensor(const string &)
Create the specified tensor.
Definition: workspace.cc:63
Definition: operator.h:31
Definition: common.h:41
void ResetTensor(const string &)
Reset the specified tensor.
Definition: workspace.cc:86
const string & name()
Return the name of this workspace.
Definition: workspace.h:33
string GetDummyName(const string &base_name, const string &suffix, const string &domain="", const bool zero_based=true)
Definition: workspace.cc:233