Dragon - C++ API
A Computation Graph Virtual Machine Based Deep Learning Framework
initialize_op.h
Go to the documentation of this file.
1 
13 #ifndef DRAGON_OPERATORS_MISC_INITIALIZE_OP_H_
14 #define DRAGON_OPERATORS_MISC_INITIALIZE_OP_H_
15 
16 #include "core/operator.h"
17 #include "utils/filler.h"
18 
19 namespace dragon {
20 
21 template <class Context>
22 class InitializeOp : public Operator<Context> {
23  public:
24  InitializeOp(const OperatorDef& def, Workspace* ws)
25  : Operator<Context>(def, ws),
26  shape_desc_(OpArg<string>("shape", "")) {
27  GET_ARGS_WITH_DESC(int64_t, dims);
28  }
30 
31  void RunOnDevice() override;
32  template <typename T> void RunImpl();
33 
34  protected:
35  string shape_desc_;
36  TensorFillerProto proto_;
37  DECLARE_ARGS_WITH_DESC(int64_t, dims);
38 };
39 
40 template <class Context>
41 class FillOp final : public Operator<Context> {
42  public:
43  FillOp(const OperatorDef& def, Workspace* ws)
44  : Operator<Context>(def, ws),
45  shape_desc_(OpArg<string>("shape", "")),
46  value_(OpArg<float>("value", 0.f)) {
47  GET_ARGS_WITH_DESC(int64_t, dims);
48  }
50 
51  void RunOnDevice() override;
52  template <typename T> void RunImpl();
53 
54  protected:
55  float value_;
56  string shape_desc_;
57  DECLARE_ARGS_WITH_DESC(int64_t, dims);
58 };
59 
namespace {

// Empty tag type that carries T purely at the type level, so an overload
// set can be selected by passing TypeIdentity<T>() as an argument.
// NOTE(review): an unnamed namespace in a header gives every including
// translation unit its own copy of this type.
template <typename T>
struct TypeIdentity {
    using type = T;
};

}  // namespace
66 
67 template <class Context>
68 class GivenTensorFillOp final : public Operator<Context> {
69  public:
70  GivenTensorFillOp(const OperatorDef& def, Workspace* ws)
71  : Operator<Context>(def, ws),
72  shape_(OpArgs<int64_t>("shape")) {
73  GET_ARGS_WITH_DESC(int64_t, dims);
74  }
76 
77  void RunOnDevice() override;
78  template <typename T> void RunImpl();
79 
80  template <typename T>
81  void Extract() { ExtractImpl(TypeIdentity<T>()); }
82 
83  template <typename T> void ExtractImpl(TypeIdentity<T>) {
84  auto raw_values = OpArgs<T>("values");
85  auto nelements = (int64_t)raw_values.size();
86  auto nbytes = nelements * sizeof(T);
87  auto* values = values_.Reshape({ nelements })
88  ->template mutable_data<T, CPUContext>();
89  memcpy(values, raw_values.data(), nbytes);
90  }
91 
92  void ExtractImpl(TypeIdentity<float16>) {
93  auto raw_values = OpArgs<float>("values");
94  auto nelements = (int64_t)raw_values.size();
95  auto nbytes = nelements * sizeof(float16);
96  auto* values = values_.Reshape({ nelements })
97  ->template mutable_data<float16, CPUContext>();
98  memcpy(values, raw_values.data(), nbytes);
99  }
100 
101  protected:
103  vector<int64_t> shape_;
104  DECLARE_ARGS_WITH_DESC(int64_t, dims);
105 };
106 
107 template <class Context>
108 class RandomUniformOp final : public InitializeOp<Context> {
109  public:
110  RandomUniformOp(const OperatorDef& def, Workspace* ws)
111  : InitializeOp<Context>(def, ws) {
112  auto low = OpArg<float>("low", -1.f);
113  auto high = OpArg<float>("high", 1.f);
114  this->proto_.set_low(low);
115  this->proto_.set_high(high);
116  this->proto_.set_type("uniform");
117  }
119 };
120 
121 template <class Context>
122 class RandomNormalOp final : public InitializeOp<Context> {
123  public:
124  RandomNormalOp(const OperatorDef& def, Workspace* ws)
125  : InitializeOp<Context>(def, ws) {
126  auto mu = OpArg<float>("mean", 0.f);
127  auto sigma = OpArg<float>("std", 1.f);
128  this->proto_.set_mean(mu);
129  this->proto_.set_std(sigma);
130  this->proto_.set_type("normal");
131  }
133 };
134 
135 template <class Context>
136 class TruncatedNormalOp final : public InitializeOp<Context> {
137  public:
138  TruncatedNormalOp(const OperatorDef& def, Workspace* ws)
139  : InitializeOp<Context>(def, ws) {
140  auto mu = OpArg<float>("mean", 0.f);
141  auto sigma = OpArg<float>("std", 1.f);
142  this->proto_.set_mean(mu);
143  this->proto_.set_std(sigma);
144  this->proto_.set_low(mu - 2 * sigma);
145  this->proto_.set_high(mu + 2 * sigma);
146  this->proto_.set_type("truncated_normal");
147  }
149 };
150 
151 template <class Context>
152 class GlorotUniformOp final : public InitializeOp<Context> {
153  public:
154  GlorotUniformOp(const OperatorDef& def, Workspace* ws)
155  : InitializeOp<Context>(def, ws) {
156  auto scale = OpArg<float>("scale", 3.f);
157  auto mode = OpArg<string>("mode", "fan_in");
158  this->proto_.set_type("xavier");
159  if (mode == "fan_avg") {
160  this->proto_.set_variance_norm(
161  TensorFillerProto_VarianceNorm_FAN_AVG);
162  } else if (mode == "fan_out") {
163  this->proto_.set_variance_norm(
164  TensorFillerProto_VarianceNorm_FAN_OUT);
165  } else {
166  this->proto_.set_variance_norm(
167  TensorFillerProto_VarianceNorm_FAN_IN);
168  }
169  this->proto_.set_scale(scale);
170  }
172 };
173 
174 template <class Context>
175 class GlorotNormalOp final : public InitializeOp<Context> {
176  public:
177  GlorotNormalOp(const OperatorDef& def, Workspace* ws)
178  : InitializeOp<Context>(def, ws) {
179  auto scale = OpArg<float>("scale", 2.f);
180  auto mode = OpArg<string>("mode", "fan_in");
181  this->proto_.set_type("msra");
182  if (mode == "fan_avg") {
183  this->proto_.set_variance_norm(
184  TensorFillerProto_VarianceNorm_FAN_AVG);
185  } else if (mode == "fan_out") {
186  this->proto_.set_variance_norm(
187  TensorFillerProto_VarianceNorm_FAN_OUT);
188  } else {
189  this->proto_.set_variance_norm(
190  TensorFillerProto_VarianceNorm_FAN_IN);
191  }
192  this->proto_.set_scale(scale);
193  }
195 };
196 
// Namespace-scope DEFINE_ARGS_WITH_DESC invocations for the `dims` argument
// group; presumably these supply the definitions matching the
// DECLARE_ARGS_WITH_DESC(int64_t, dims) lines inside InitializeOp and FillOp.
// NOTE(review): GivenTensorFillOp also declares `dims`, but no matching
// invocation is visible here and the embedded listing numbering skips 199;
// confirm against the original header.
197 DEFINE_ARGS_WITH_DESC(int64_t, InitializeOp, dims);
198 DEFINE_ARGS_WITH_DESC(int64_t, FillOp, dims);
200 
201 } // namespace dragon
202 
203 #endif // DRAGON_OPERATORS_MISC_INITIALIZE_OP_H_
RandomUniformOp(const OperatorDef &def, Workspace *ws)
Definition: initialize_op.h:110
Definition: workspace.h:20
void Extract()
Definition: initialize_op.h:81
DECLARE_ARGS_WITH_DESC(int64_t, dims)
USE_OPERATOR_FUNCTIONS
Definition: initialize_op.h:132
USE_OPERATOR_FUNCTIONS
Definition: initialize_op.h:118
const OperatorDef & def() const
Return the stored def.
Definition: operator.h:114
Workspace * ws() const
Return the parent workspace.
Definition: operator.h:87
string shape_desc_
Definition: initialize_op.h:35
Definition: initialize_op.h:152
RandomNormalOp(const OperatorDef &def, Workspace *ws)
Definition: initialize_op.h:124
Definition: initialize_op.h:22
void RunOnDevice() override
Implement the detailed execution.
Definition: initialize_op.cc:45
TruncatedNormalOp(const OperatorDef &def, Workspace *ws)
Definition: initialize_op.h:138
Definition: tensor.h:21
void ExtractImpl(TypeIdentity< float16 >)
Definition: initialize_op.h:92
GlorotNormalOp(const OperatorDef &def, Workspace *ws)
Definition: initialize_op.h:177
InitializeOp(const OperatorDef &def, Workspace *ws)
Definition: initialize_op.h:24
Definition: initialize_op.h:175
USE_OPERATOR_FUNCTIONS
Definition: initialize_op.h:194
FillOp(const OperatorDef &def, Workspace *ws)
Definition: initialize_op.h:43
#define OpArgs
Definition: operator.h:236
Definition: operator.h:149
Definition: initialize_op.h:41
#define OpArg
Definition: operator.h:235
USE_OPERATOR_FUNCTIONS
Definition: initialize_op.h:75
void RunImpl()
Definition: initialize_op.cc:7
USE_OPERATOR_FUNCTIONS
Definition: initialize_op.h:29
string shape_desc_
Definition: initialize_op.h:56
Definition: initialize_op.h:108
GlorotUniformOp(const OperatorDef &def, Workspace *ws)
Definition: initialize_op.h:154
Definition: initialize_op.h:68
DECLARE_ARGS_WITH_DESC(int64_t, dims)
void RunOnDevice() override
Implement the detailed execution.
Definition: initialize_op.cc:104
Tensor * Reshape(const vec64_t &dims)
Reshape to the given dimensions.
Definition: tensor.h:38
GivenTensorFillOp(const OperatorDef &def, Workspace *ws)
Definition: initialize_op.h:70
float value_
Definition: initialize_op.h:55
void ExtractImpl(TypeIdentity< T >)
Definition: initialize_op.h:83
DECLARE_ARGS_WITH_DESC(int64_t, dims)
USE_OPERATOR_FUNCTIONS
Definition: initialize_op.h:148
vector< int64_t > shape_
Definition: initialize_op.h:103
DEFINE_ARGS_WITH_DESC(int64_t, CropOp, starts)
USE_OPERATOR_FUNCTIONS
Definition: initialize_op.h:171
void RunImpl()
Definition: initialize_op.cc:33
void RunImpl()
Definition: initialize_op.cc:87
#define GET_ARGS_WITH_DESC(type, arg)
Definition: operator.h:412
TensorFillerProto proto_
Definition: initialize_op.h:36
Definition: initialize_op.h:122
Definition: common.h:41
void RunOnDevice() override
Implement the detailed execution.
Definition: initialize_op.cc:14
USE_OPERATOR_FUNCTIONS
Definition: initialize_op.h:49
Definition: initialize_op.h:136
Tensor values_
Definition: initialize_op.h:102