Dragon - C++ API
A Computation Graph Virtual Machine Based Deep Learning Framework
filler.h
Go to the documentation of this file.
1 
13 #ifndef DRAGON_UTILS_FILLER_H_
14 #define DRAGON_UTILS_FILLER_H_
15 
16 #include "core/registry.h"
17 #include "utils/math_functions.h"
18 
19 namespace dragon {
20 
21 template <typename T, class Context>
22 class Filler {
23  public:
24  Filler(const TensorFillerProto& proto)
25  : proto_(proto) {}
26 
27  virtual void Fill(Tensor* X, Context* ctx) = 0;
28 
29  inline TensorFillerProto& proto() { return proto_; }
30 
31  protected:
32  TensorFillerProto proto_;
33 };
34 
35 template <typename T, class Context>
36 class ConstantFiller final : public Filler<T, Context> {
37  public:
38  ConstantFiller(const TensorFillerProto& proto)
39  : Filler<T, Context>(proto) {}
40 
41  void Fill(Tensor* X, Context* ctx) override {
42  math::Set(
43  X->count(),
44  cast::to<T>(proto().value()),
45  X->mutable_data<T, Context>(), ctx
46  );
47  }
48 
49  protected:
51 };
52 
53 template <typename T, class Context>
54 class NormalFiller final : public Filler<T, Context> {
55  public:
56  NormalFiller(const TensorFillerProto& proto)
57  : Filler<T, Context>(proto) {}
58 
59  void Fill(Tensor* X, Context* ctx) override {
61  X->count(),
62  proto().mean(), proto().std(),
63  X->mutable_data<T, Context>(), ctx
64  );
65  }
66 
67  protected:
69 };
70 
71 template <typename T, class Context>
72 class TruncatedNormalFiller final : public Filler<T, Context> {
73  public:
74  TruncatedNormalFiller(const TensorFillerProto& proto)
75  : Filler<T, Context>(proto) {}
76 
77  void Fill(Tensor* X, Context* ctx) override {
78  // It's difficult to implement it on gpu
80  X->count(),
81  proto().mean(), proto().std(),
82  proto().low(), proto().high(),
83  X->mutable_data<T, CPUContext>(), &cctx_
84  );
85  }
86 
87  protected:
90 };
91 
92 template <typename T, class Context>
93 class UniformFiller final : public Filler<T, Context> {
94  public:
95  UniformFiller(const TensorFillerProto& proto)
96  : Filler<T, Context>(proto) {}
97 
98  void Fill(Tensor* X, Context* ctx) override {
100  X->count(),
101  proto().low(), proto().high(),
102  X->mutable_data<T, Context>(), ctx
103  );
104  }
105 
106  protected:
108 };
109 
110 template <typename T, class Context>
111 class XavierFiller final : public Filler<T, Context> {
112  public:
113  XavierFiller(const TensorFillerProto& proto)
114  : Filler<T, Context>(proto) {}
115 
116  void Fill(Tensor* X, Context* ctx) override {
117  auto fan_in = X->count() / X->dim(0);
118  auto fan_out = X->count() / X->dim(1);
119  float n = (float)fan_in, scale = 3.f;
120  if (proto().has_scale()) scale = proto().scale();
121  if (proto().variance_norm() ==
122  TensorFillerProto_VarianceNorm_FAN_AVG) {
123  n = (fan_in + fan_out) / 2.f;
124  } else if (proto().variance_norm() ==
125  TensorFillerProto_VarianceNorm_FAN_OUT) {
126  n = (float)fan_out;
127  }
128  float limit = std::sqrt(scale / n);
130  X->count(),
131  -limit, limit,
132  X->mutable_data<T, Context>(), ctx
133  );
134  }
135 
136  protected:
138 };
139 
140 template <typename T, class Context>
141 class MSRAFiller final : public Filler <T, Context> {
142  public:
143  MSRAFiller(const TensorFillerProto& proto)
144  : Filler<T, Context>(proto) {}
145 
146  void Fill(Tensor* X, Context* ctx) override {
147  auto fan_in = X->count() / X->dim(0);
148  auto fan_out = X->count() / X->dim(1);
149  float n = (float)fan_in, scale = 2.f;
150  if (proto().has_scale()) scale = proto().scale();
151  if (proto().variance_norm() ==
152  TensorFillerProto_VarianceNorm_FAN_AVG) {
153  n = (fan_in + fan_out) / 2.f;
154  } else if (proto().variance_norm() ==
155  TensorFillerProto_VarianceNorm_FAN_OUT) {
156  n = (float)fan_out;
157  }
158  float std = std::sqrt(scale / n);
160  X->count(),
161  0.f, std,
162  X->mutable_data<T, Context>(), ctx
163  );
164  }
165 
166  protected:
168 };
169 
170 template <typename T, class Context>
171 Filler<T, Context>* CreateFiller(const TensorFillerProto& proto) {
172  const string& type = proto.type();
173  if (type == "constant") {
174  return new ConstantFiller<T, Context>(proto);
175  } else if (type == "uniform") {
176  return new UniformFiller<T, Context>(proto);
177  } else if (type == "normal") {
178  return new NormalFiller<T, Context>(proto);
179  } else if (type == "truncated_normal") {
180  return new TruncatedNormalFiller<T, Context>(proto);
181  } else if (type == "xavier" || type == "glorot_uniform") {
182  return new XavierFiller<T, Context>(proto);
183  } else if (type == "msra" || type == "glorot_normal") {
184  return new MSRAFiller<T, Context>(proto);
185  } return new ConstantFiller<T, Context>(proto);
186 }
187 
188 } // namespace dragon
189 
190 #endif // DRAGON_UTILS_FILLER_H_
int64_t count(int64_t start, int64_t end) const
Return the number of elements along the [start, end) axes.
Definition: tensor.h:100
Definition: filler.h:54
CPUContext cctx_
Definition: filler.h:88
void Fill(Tensor *X, Context *ctx) override
Definition: filler.h:59
TensorFillerProto proto_
Definition: filler.h:32
Definition: filler.h:72
Definition: filler.h:141
void Fill(Tensor *X, Context *ctx) override
Definition: filler.h:41
UniformFiller(const TensorFillerProto &proto)
Definition: filler.h:95
void RandomNormal(const int n, const float mu, const float sigma, T *y, Context *ctx)
void RandomTruncatedNormal(const int n, const float mu, const float sigma, const float low, const float high, T *y, Context *ctx)
Definition: tensor.h:21
T * mutable_data()
Get the typed mutable data pointer.
Definition: tensor.h:262
void Fill(Tensor *X, Context *ctx) override
Definition: filler.h:146
NormalFiller(const TensorFillerProto &proto)
Definition: filler.h:56
Definition: context.h:20
TensorFillerProto & proto()
Definition: filler.h:29
TruncatedNormalFiller(const TensorFillerProto &proto)
Definition: filler.h:74
void Fill(Tensor *X, Context *ctx) override
Definition: filler.h:98
Filler(const TensorFillerProto &proto)
Definition: filler.h:24
XavierFiller(const TensorFillerProto &proto)
Definition: filler.h:113
int64_t dim(int64_t i) const
Return the dimension of given axis.
Definition: tensor.h:85
Filler< T, Context > * CreateFiller(const TensorFillerProto &proto)
Definition: filler.h:171
void RandomUniform(const int n, const float low, const float high, T *y, Context *ctx)
Definition: filler.h:93
virtual void Fill(Tensor *X, Context *ctx)=0
void Fill(Tensor *X, Context *ctx) override
Definition: filler.h:116
Definition: filler.h:36
void Set(const int n, const T alpha, T *y, Context *ctx)
Definition: common.h:41
Definition: filler.h:22
MSRAFiller(const TensorFillerProto &proto)
Definition: filler.h:143
ConstantFiller(const TensorFillerProto &proto)
Definition: filler.h:38
void Fill(Tensor *X, Context *ctx) override
Definition: filler.h:77
Definition: filler.h:111