Dragon - C++ API
A Computation Graph Virtual Machine Based Deep Learning Framework
math_functions.h
Go to the documentation of this file.
1 
13 #ifndef DRAGON_UTILS_MATH_FUNCTIONS_H_
14 #define DRAGON_UTILS_MATH_FUNCTIONS_H_
15 
16 #include <cstdint>
17 #include <climits>
18 
19 #include "proto/dragon.pb.h"
20 
21 namespace dragon {
22 
23 // We still follow the CBLAS Transpose custom
24 typedef enum CBLAS_TRANSPOSE {
28 
29 namespace math {
30 
41 template <typename T, class Context>
42 void Copy(
43  const int n,
44  const T* x,
45  T* y,
46  Context* ctx);
47 
48 template <typename T, class Context>
49 void Exp(
50  const int n,
51  const T* x,
52  T* y,
53  Context* ctx);
54 
55 template <typename T, class Context>
56 void Log(
57  const int n,
58  const T* x,
59  T* y,
60  Context* ctx);
61 
62 template <typename T, class Context>
63 void Inv(
64  const int n,
65  const T* x,
66  T* y,
67  Context* ctx);
68 
69 template <typename T, class Context>
70 void Sqrt(
71  const int n,
72  const T* x,
73  T* y,
74  Context* ctx);
75 
76 template <typename T, class Context>
77 void RSqrt(
78  const int n,
79  const T* x,
80  T* y,
81  Context* ctx);
82 
83 template <typename T, class Context>
84 void Square(
85  const int n,
86  const T* x,
87  T* y,
88  Context* ctx);
89 
100 template <typename T, class Context>
101 void Set(
102  const int n,
103  const T alpha,
104  T* y,
105  Context* ctx);
106 
107 template <typename T, class Context>
108 void BroadcastSet(
109  const int rows,
110  const int cols,
111  const int type,
112  const T* x,
113  T* y,
114  Context* ctx);
115 
116 template <typename T, class Context>
117 void Pow(
118  const int n,
119  const float exp,
120  const T* x,
121  T* y,
122  Context* ctx);
123 
124 template <typename T, class Context>
125 void Scale(
126  const int n,
127  const float alpha,
128  const T* x,
129  T* y,
130  Context* ctx);
131 
132 template<typename T, class Context>
133 void Axpy(
134  const int n,
135  const float alpha,
136  const T* x,
137  T* y,
138  Context* ctx);
139 
140 template<typename T, class Context>
141 void Axpby(
142  const int n,
143  const float alpha,
144  const T* x,
145  const float beta,
146  T* y,
147  Context* ctx);
148 
149 template<typename T, class Context>
150 void AddScalar(
151  const int n,
152  const float alpha,
153  T* y,
154  Context* ctx);
155 
166 template<typename T, class Context>
167 void InvStd(
168  const int n,
169  const float eps,
170  const T* x,
171  T* y,
172  Context* ctx);
173 
174 template<typename T, class Context>
175 void Sum(
176  const int n,
177  const float alpha,
178  const T* x,
179  T* y,
180  Context* ctx);
181 
182 template<typename T, class Context>
183 T Sum(
184  const int n,
185  const float alpha,
186  const T* x,
187  Context* ctx);
188 
189 template<typename T, class Context>
190 T ASum(
191  const int n,
192  const T* x,
193  Context* ctx);
194 
205 template <typename T, class Context>
206 void Add(
207  const int n,
208  const T* a,
209  const T* b,
210  T* y,
211  Context* ctx);
212 
213 template <typename T, class Context>
214 void Sub(
215  const int n,
216  const T* a,
217  const T* b,
218  T* y,
219  Context* ctx);
220 
221 template <typename T, class Context>
222 void Mul(
223  const int n,
224  const T* a,
225  const T* b,
226  T* y,
227  Context* ctx);
228 
229 template <typename T, class Context>
230 void Div(
231  const int n,
232  const T* a,
233  const T* b,
234  T* y,
235  Context* ctx);
236 
237 template <typename T, class Context>
238 void Dot(
239  const int n,
240  const T* a,
241  const T* b,
242  T* y,
243  Context* ctx);
244 
255 template <typename T, class Context>
256 void BroadcastAdd(
257  const int rows,
258  const int cols,
259  const int type,
260  const T* a,
261  const T* b,
262  T* y,
263  Context* ctx);
264 
265 template <typename T, class Context>
266 void BroadcastSub(
267  const int rows,
268  const int cols,
269  const int type,
270  const T* a,
271  const T* b,
272  T* y,
273  Context* ctx);
274 
275 template <typename T, class Context>
276 void BroadcastMul(
277  const int rows,
278  const int cols,
279  const int type,
280  const T* a,
281  const T* b,
282  T* y,
283  Context* ctx);
284 
285 template <typename T, class Context>
286 void BroadcastDiv(
287  const int rows,
288  const int cols,
289  const int type,
290  const T* a,
291  const T* b,
292  T* y,
293  Context* ctx);
294 
305 template <typename T, class Context>
306 void Gemm(
307  const CBLAS_TRANSPOSE TransA,
308  const CBLAS_TRANSPOSE TransB,
309  const int M,
310  const int N,
311  const int K,
312  const float alpha,
313  const T* A,
314  const T* B,
315  const float beta,
316  T* C,
317  Context* ctx,
318  TensorProto_DataType math_type = TensorProto_DataType_FLOAT);
319 
320 template<typename T, class Context>
321 void Gemv(
322  const CBLAS_TRANSPOSE TransA,
323  const int M,
324  const int N,
325  const float alpha,
326  const T* A,
327  const T* x,
328  const float beta,
329  T* y,
330  Context* ctx,
331  TensorProto_DataType math_type = TensorProto_DataType_FLOAT);
332 
343 template <typename T, class Context>
344 void RandomUniform(
345  const int n,
346  const float low,
347  const float high,
348  T* y,
349  Context* ctx);
350 
351 template <typename T, class Context>
352 void RandomNormal(
353  const int n,
354  const float mu,
355  const float sigma,
356  T* y,
357  Context* ctx);
358 
359 template <typename T, class Context>
361  const int n,
362  const float mu,
363  const float sigma,
364  const float low,
365  const float high,
366  T* y,
367  Context* ctx);
368 
369 template <typename T, class Context>
370 void RandomBernoulli(
371  const int n,
372  const float p,
373  T* y,
374  Context* ctx);
375 
376 } // namespace math
377 
378 } // namespace dragon
379 
380 #endif // DRAGON_UTILS_MATH_FUNCTIONS_H_
void RandomBernoulli(const int n, const float p, T *y, Context *ctx)
void RSqrt(const int n, const T *x, T *y, Context *ctx)
void Sub(const int n, const T *a, const T *b, T *y, Context *ctx)
Definition: math_functions.h:25
void Sqrt(const int n, const T *x, T *y, Context *ctx)
void Mul(const int n, const T *a, const T *b, T *y, Context *ctx)
void InvStd(const int n, const float eps, const T *x, T *y, Context *ctx)
void Pow(const int n, const float exp, const T *x, T *y, Context *ctx)
void BroadcastDiv(const int rows, const int cols, const int type, const T *a, const T *b, T *y, Context *ctx)
void RandomNormal(const int n, const float mu, const float sigma, T *y, Context *ctx)
void Exp(const int n, const T *x, T *y, Context *ctx)
void RandomTruncatedNormal(const int n, const float mu, const float sigma, const float low, const float high, T *y, Context *ctx)
void Div(const int n, const T *a, const T *b, T *y, Context *ctx)
void AddScalar(const int n, const float alpha, T *y, Context *ctx)
void Add(const int n, const T *a, const T *b, T *y, Context *ctx)
void Scale(const int n, const float alpha, const T *x, T *y, Context *ctx)
Definition: math_functions.h:26
T ASum(const int n, const T *x, Context *ctx)
void BroadcastSub(const int rows, const int cols, const int type, const T *a, const T *b, T *y, Context *ctx)
CBLAS_TRANSPOSE
Definition: math_functions.h:24
void Axpy(const int n, const float alpha, const T *x, T *y, Context *ctx)
void Gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K, const float alpha, const T *A, const T *B, const float beta, T *C, Context *ctx, TensorProto_DataType math_type=TensorProto_DataType_FLOAT)
void Axpby(const int n, const float alpha, const T *x, const float beta, T *y, Context *ctx)
void Dot(const int n, const T *a, const T *b, T *y, Context *ctx)
void Inv(const int n, const T *x, T *y, Context *ctx)
void Gemv(const CBLAS_TRANSPOSE TransA, const int M, const int N, const float alpha, const T *A, const T *x, const float beta, T *y, Context *ctx, TensorProto_DataType math_type=TensorProto_DataType_FLOAT)
void Square(const int n, const T *x, T *y, Context *ctx)
void Sum(const int n, const float alpha, const T *x, T *y, Context *ctx)
void RandomUniform(const int n, const float low, const float high, T *y, Context *ctx)
void BroadcastAdd(const int rows, const int cols, const int type, const T *a, const T *b, T *y, Context *ctx)
void Copy(const int n, const T *x, T *y, Context *ctx)
void Set(const int n, const T alpha, T *y, Context *ctx)
void BroadcastSet(const int rows, const int cols, const int type, const T *x, T *y, Context *ctx)
Definition: common.h:41
void BroadcastMul(const int rows, const int cols, const int type, const T *a, const T *b, T *y, Context *ctx)
void Log(const int n, const T *x, T *y, Context *ctx)