Dragon - C++ API
A Computation Graph Virtual Machine Based Deep Learning Framework
batch_norm_op.h
Go to the documentation of this file.
1 
13 #ifndef DRAGON_OPERATORS_NORM_BATCH_NORM_OP_H_
14 #define DRAGON_OPERATORS_NORM_BATCH_NORM_OP_H_
15 
16 #include <cfloat>
17 
18 #include "core/operator.h"
19 
20 namespace dragon {
21 
// BatchNormOp: declaration of the batch normalization operator, templated on
// the device Context. Only the interface is declared here; the cross-reference
// list at the end of this listing places Reshape, RunOnDevice, TrainingImpl
// and InferenceImpl in batch_norm_op.cc.
// NOTE(review): this listing skips embedded lines 31 and 42-43. According to
// the cross-reference list they hold USE_OPERATOR_FUNCTIONS and further
// members (scale_, bias_, mean_, var_, is_training_, is_recomp_).
22 template <class Context>
23 class BatchNormOp : public Operator<Context> {
24  public:
// Reads the operator arguments "axis" (default -1), "momentum" (default 0.9),
// "eps" (default 1e-5) and "use_stats" (default -1).
// Members are initialized in declaration order (momentum_, eps_, axis_,
// use_stats_), not in the order written in this list; the initializers do not
// reference one another.
// NOTE(review): the meaning of the -1 defaults is not visible in this header;
// confirm against batch_norm_op.cc.
25  BatchNormOp(const OperatorDef& def, Workspace* ws)
26  : Operator<Context>(def, ws),
27  axis_(OpArg<int64_t>("axis", -1)),
28  momentum_(OpArg<float>("momentum", 0.9f)),
29  eps_(OpArg<float>("eps", 1e-5f)),
30  use_stats_(OpArg<int64_t>("use_stats", -1)) {}
32 
// Non-virtual helper; presumably fills N_, C_ and S_ from the input shape.
// TODO confirm in batch_norm_op.cc.
33  void Reshape();
34 
// Overrides the execution entry point declared by Operator<Context>.
35  void RunOnDevice() override;
// Implementations parameterized on two types, Tx and Tp.
// NOTE(review): looks like Tx is the data type and Tp the parameter type;
// verify in the .cc file.
36  template <typename Tx, typename Tp> void TrainingImpl();
37  template <typename Tx, typename Tp> void InferenceImpl();
38 
39  protected:
// Values cached from the "momentum" and "eps" arguments.
40  float momentum_, eps_;
// axis_ and use_stats_ come from the arguments; N_, C_ and S_ are not set by
// the constructor.
41  int64_t axis_, use_stats_, N_, C_, S_;
44 };
45 
// BatchNormGradientOp: declaration of the gradient counterpart of the batch
// normalization operator, templated on the device Context. Member functions
// are only declared here.
// NOTE(review): this listing skips embedded lines 54 and 65-66. According to
// the cross-reference list they hold USE_OPERATOR_FUNCTIONS, the declarations
// of axis_, use_stats_ and is_training_, and the members dscale_, dbias_,
// mean_ and var_.
46 template <class Context>
47 class BatchNormGradientOp : public Operator<Context> {
48  public:
// Reads the operator arguments "axis" (default -1), "eps" (default 1e-5) and
// "use_stats" (default -1). Unlike BatchNormOp, no "momentum" argument is
// read.
49  BatchNormGradientOp(const OperatorDef& def, Workspace* ws)
50  : Operator<Context>(def, ws),
51  axis_(OpArg<int64_t>("axis", -1)),
52  eps_(OpArg<float>("eps", 1e-5f)),
53  use_stats_(OpArg<int64_t>("use_stats", -1)) {}
55 
// Non-virtual helper; presumably derives the dimension members from the input
// shape. TODO confirm in batch_norm_op.cc.
56  void Reshape();
57 
// Overrides the execution entry point declared by Operator<Context>.
58  void RunOnDevice() override;
// Implementations parameterized on two types, Tx and Tp.
59  template <typename Tx, typename Tp> void TrainingImpl();
60  template <typename Tx, typename Tp> void InferenceImpl();
61 
62  protected:
// Value cached from the "eps" argument.
63  float eps_;
// Not set by the constructor.
// NOTE(review): NC_ and NS_ look like the products N*C and N*S; verify in the
// .cc file.
64  int64_t N_, C_, S_, NC_, NS_;
67 };
68 
69 #ifdef WITH_CUDNN
70 
71 #if CUDNN_VERSION_MIN(5, 0, 0)
72 
// CuDNNBatchNormOp: cuDNN-backed variant of BatchNormOp, compiled only when
// WITH_CUDNN is defined and CUDNN_VERSION_MIN(5, 0, 0) holds. It owns two
// cuDNN tensor descriptors that are created in the constructor and destroyed
// in the destructor.
// NOTE(review): this listing skips embedded line 89, so part of the class
// body is not visible here.
73 template <class Context>
74 class CuDNNBatchNormOp final : public BatchNormOp<Context> {
75  public:
// Forwards (def, ws) to BatchNormOp, then reads "axis", "eps" and "use_stats"
// again into this class's own members. The "eps" argument is read as float
// and widened into the double member eps64_.
76  CuDNNBatchNormOp(const OperatorDef& def, Workspace* ws)
77  : BatchNormOp<Context>(def, ws),
78  axis_(OpArg<int64_t>("axis", -1)),
79  eps64_(OpArg<float>("eps", 1e-5f)),
80  use_stats_(OpArg<int64_t>("use_stats", -1)) {
81  CuDNNCreateTensorDesc(&bn_desc_);
82  CuDNNCreateTensorDesc(&input_desc_);
// The if has no braces, so only the LOG statement is conditional. It fires
// when eps64_ is at or below CUDNN_BN_MIN_EPSILON - FLT_EPSILON.
// NOTE(review): LOG(FATAL) presumably aborts; confirm in logging.h.
83  if (eps64_ <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON)
84  LOG(FATAL) << "Provided epsilon is smaller than "
85  << "CUDNN_BN_MIN_EPSILON. \nSet it to "
86  << "CUDNN_BN_MIN_EPSILON instead.";
// Always executed: raises eps64_ to CUDNN_BN_MIN_EPSILON when it is below it,
// which covers values that passed the check above but are still under the
// minimum.
// NOTE(review): std::max needs <algorithm>, which this header does not
// include directly; it presumably arrives through core/operator.h.
87  eps64_ = std::max(eps64_, CUDNN_BN_MIN_EPSILON);
88  }
90 
// Releases the two descriptors created in the constructor.
91  ~CuDNNBatchNormOp() {
92  CuDNNDestroyTensorDesc(&bn_desc_);
93  CuDNNDestroyTensorDesc(&input_desc_);
94  }
95 
// Declares a Reshape with the same signature as the non-virtual
// BatchNormOp::Reshape, hiding it.
96  void Reshape();
97 
98  void RunOnDevice() override;
// A single implementation template with one type parameter, in contrast to
// the base class's TrainingImpl/InferenceImpl pair.
99  template <typename T> void RunImpl();
100 
101  protected:
// Epsilon in double precision, after the clamp in the constructor.
102  double eps64_;
// axis_, N_, C_ and use_stats_ repeat names that BatchNormOp also declares;
// inside this class the unqualified names refer to these copies.
103  int64_t axis_, N_, C_;
104  int64_t use_stats_, is_training_, is_recomp_;
// Raw pointers that the constructor leaves unset.
// NOTE(review): presumably non-owning and assigned in the .cc file.
105  Tensor* mean_, *var_;
// Descriptors created in the constructor and destroyed in the destructor.
106  cudnnTensorDescriptor_t input_desc_, bn_desc_;
// Not set by the constructor.
107  cudnnBatchNormMode_t bn_mode_;
108 };
109 
// CuDNNBatchNormGradientOp: cuDNN-backed variant of BatchNormGradientOp,
// compiled only when WITH_CUDNN is defined and CUDNN_VERSION_MIN(5, 0, 0)
// holds. It owns two cuDNN tensor descriptors that are created in the
// constructor and destroyed in the destructor.
// NOTE(review): this listing skips embedded line 127, so part of the class
// body is not visible here.
110 template <class Context>
111 class CuDNNBatchNormGradientOp final
112  : public BatchNormGradientOp<Context> {
113  public:
// Forwards (def, ws) to BatchNormGradientOp, then reads "axis", "eps" and
// "use_stats" again into this class's own members. The "eps" argument is read
// as float and widened into the double member eps64_.
114  CuDNNBatchNormGradientOp(const OperatorDef& def, Workspace* ws)
115  : BatchNormGradientOp<Context>(def, ws),
116  axis_(OpArg<int64_t>("axis", -1)),
117  eps64_(OpArg<float>("eps", 1e-5f)),
118  use_stats_(OpArg<int64_t>("use_stats", -1)) {
119  CuDNNCreateTensorDesc(&bn_desc_);
120  CuDNNCreateTensorDesc(&input_desc_);
// The if has no braces, so only the LOG statement is conditional. It fires
// when eps64_ is at or below CUDNN_BN_MIN_EPSILON - FLT_EPSILON.
// NOTE(review): LOG(FATAL) presumably aborts; confirm in logging.h.
121  if (eps64_ <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON)
122  LOG(FATAL) << "Provided epsilon is smaller than "
123  << "CUDNN_BN_MIN_EPSILON. \nSet it to "
124  << "CUDNN_BN_MIN_EPSILON instead.";
// Always executed: raises eps64_ to CUDNN_BN_MIN_EPSILON when it is below it.
// NOTE(review): std::max needs <algorithm>, which this header does not
// include directly; it presumably arrives through core/operator.h.
125  eps64_ = std::max(eps64_, CUDNN_BN_MIN_EPSILON);
126  }
128 
// Releases the two descriptors created in the constructor.
129  ~CuDNNBatchNormGradientOp() {
130  CuDNNDestroyTensorDesc(&bn_desc_);
131  CuDNNDestroyTensorDesc(&input_desc_);
132  }
133 
// Declares a Reshape with the same signature as the non-virtual
// BatchNormGradientOp::Reshape, hiding it.
134  void Reshape();
135 
136  void RunOnDevice() override;
// Single-type-parameter implementations; the base class declares
// two-parameter templates with the same names.
137  template <typename T> void TrainingImpl();
138  template <typename T> void InferenceImpl();
139 
140  protected:
// Epsilon in double precision, after the clamp in the constructor.
141  double eps64_;
// N_, C_ and S_ repeat names visible in BatchNormGradientOp, and axis_ and
// use_stats_ repeat names its constructor initializes; inside this class the
// unqualified names refer to these copies.
142  int64_t axis_, N_, C_, S_;
143  int64_t use_stats_, is_training_;
// Raw pointers that the constructor leaves unset.
// NOTE(review): presumably non-owning and assigned in the .cc file.
144  Tensor* mean_, *var_;
// Descriptors created in the constructor and destroyed in the destructor.
145  cudnnTensorDescriptor_t input_desc_, bn_desc_;
// Not set by the constructor.
146  cudnnBatchNormMode_t bn_mode_;
147 };
148 
149 #endif
150 
151 #endif // WITH_CUDNN
152 
153 } // namespace dragon
154 
155 #endif // DRAGON_OPERATORS_NORM_BATCH_NORM_OP_H_
Definition: batch_norm_op.h:23
int64_t C_
Definition: batch_norm_op.h:64
Tensor dbias_
Definition: batch_norm_op.h:66
int64_t C_
Definition: batch_norm_op.h:41
Definition: logging.h:21
bool is_training_
Definition: batch_norm_op.h:43
int64_t N_
Definition: batch_norm_op.h:64
USE_OPERATOR_FUNCTIONS
Definition: batch_norm_op.h:31
float eps_
Definition: batch_norm_op.h:40
Definition: workspace.h:20
int64_t S_
Definition: batch_norm_op.h:41
USE_OPERATOR_FUNCTIONS
Definition: batch_norm_op.h:54
int64_t use_stats_
Definition: batch_norm_op.h:41
float momentum_
Definition: batch_norm_op.h:40
const OperatorDef & def() const
Return the stored def.
Definition: operator.h:114
Workspace * ws() const
Return the parent workspace.
Definition: operator.h:87
Tensor dscale_
Definition: batch_norm_op.h:66
bool is_recomp_
Definition: batch_norm_op.h:43
Tensor bias_
Definition: batch_norm_op.h:42
int64_t axis_
Definition: batch_norm_op.h:41
Tensor * var_
Definition: batch_norm_op.h:42
void CuDNNCreateTensorDesc(cudnnTensorDescriptor_t *desc)
Definition: cudnn_device.h:67
Definition: tensor.h:21
void TrainingImpl()
Definition: batch_norm_op.cc:11
void InferenceImpl()
Definition: batch_norm_op.cc:205
void Reshape()
Definition: batch_norm_op.cc:234
float eps_
Definition: batch_norm_op.h:63
Tensor * mean_
Definition: batch_norm_op.h:66
void RunOnDevice() override
Implement the detailed execution.
Definition: batch_norm_op.cc:166
Definition: operator.h:149
#define OpArg
Definition: operator.h:235
#define USE_OPERATOR_FUNCTIONS
Definition: operator.h:261
BatchNormOp(const OperatorDef &def, Workspace *ws)
Definition: batch_norm_op.h:25
int64_t is_training_
Definition: batch_norm_op.h:65
int64_t S_
Definition: batch_norm_op.h:64
Definition: batch_norm_op.h:47
void CuDNNDestroyTensorDesc(cudnnTensorDescriptor_t *desc)
Definition: cudnn_device.h:72
Tensor * mean_
Definition: batch_norm_op.h:42
void InferenceImpl()
Definition: batch_norm_op.cc:88
int64_t NS_
Definition: batch_norm_op.h:64
int64_t use_stats_
Definition: batch_norm_op.h:65
#define LOG(severity)
Definition: logging.h:54
int64_t N_
Definition: batch_norm_op.h:41
void TrainingImpl()
Definition: batch_norm_op.cc:184
int64_t NC_
Definition: batch_norm_op.h:64
void RunOnDevice() override
Implement the detailed execution.
Definition: batch_norm_op.cc:265
void Reshape()
Definition: batch_norm_op.cc:127
Tensor scale_
Definition: batch_norm_op.h:42
Tensor * var_
Definition: batch_norm_op.h:66
int64_t axis_
Definition: batch_norm_op.h:65
Definition: common.h:41
BatchNormGradientOp(const OperatorDef &def, Workspace *ws)
Definition: batch_norm_op.h:49