Dragon - C++ API
A Computation Graph Virtual Machine Based Deep Learning Framework
relu_op.h
Go to the documentation of this file.
1 
13 #ifndef DRAGON_OPERATORS_ACTIVATION_RELU_OP_H_
14 #define DRAGON_OPERATORS_ACTIVATION_RELU_OP_H_
15 
16 #include "core/operator.h"
17 
18 namespace dragon {
19 
20 template <class Context>
21 class ReluOp : public Operator<Context> {
22  public:
23  ReluOp(const OperatorDef& def, Workspace* ws)
24  : Operator<Context>(def, ws),
25  slope_(OpArg<float>("slope", 0.f)) {}
27 
28  void RunOnDevice() override;
29  template <typename T> void RunImpl();
30 
31  protected:
32  float slope_;
33 };
34 
35 template <class Context>
36 class ReluGradientOp : public Operator<Context> {
37  public:
38  ReluGradientOp(const OperatorDef& def, Workspace* ws)
39  : Operator<Context>(def, ws),
40  slope_(OpArg<float>("slope", 0.f)) {}
42 
43  void RunOnDevice() override;
44  template <typename T> void RunImpl();
45 
46  protected:
47  float slope_;
48 };
49 
50 #ifdef WITH_CUDNN
51 
52 template <class Context>
53 class CuDNNReluOp final : public ReluOp<Context> {
54 public:
55  CuDNNReluOp(const OperatorDef& def, Workspace* ws)
56  : ReluOp<Context>(def, ws) {
58  CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc_));
59  CUDNN_CHECK(cudnnSetActivationDescriptor(
60  act_desc_,
61  CUDNN_ACTIVATION_RELU,
62  CUDNN_PROPAGATE_NAN, 0
63  ));
64  }
66 
69  CUDNN_CHECK(cudnnDestroyActivationDescriptor(act_desc_));
70  }
71 
72  void RunOnDevice() override;
73  template <typename T> void RunImpl();
74 
75  protected:
76  cudnnTensorDescriptor_t input_desc_;
77  cudnnActivationDescriptor_t act_desc_;
78 };
79 
80 template <class Context>
81 class CuDNNReluGradientOp final : public ReluGradientOp<Context> {
82  public:
83  CuDNNReluGradientOp(const OperatorDef& def, Workspace* ws)
84  : ReluGradientOp<Context>(def, ws) {
86  CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc_));
87  CUDNN_CHECK(cudnnSetActivationDescriptor(
88  act_desc_,
89  CUDNN_ACTIVATION_RELU,
90  CUDNN_PROPAGATE_NAN, 0
91  ));
92  }
94 
97  CUDNN_CHECK(cudnnDestroyActivationDescriptor(act_desc_));
98  }
99 
100  void RunOnDevice() override;
101  template <typename T> void RunImpl();
102 
103  protected:
104  cudnnTensorDescriptor_t input_desc_;
105  cudnnActivationDescriptor_t act_desc_;
106 };
107 
108 #endif // WITH_CUDNN
109 
110 } // namespace dragon
111 
112 #endif // DRAGON_OPERATORS_ACTIVATION_RELU_OP_H_
cudnnActivationDescriptor_t act_desc_
Definition: relu_op.h:105
ReluOp(const OperatorDef &def, Workspace *ws)
Definition: relu_op.h:23
Definition: workspace.h:20
#define CUDNN_CHECK(condition)
Definition: cudnn_device.h:34
void RunImpl()
Definition: cudnn_relu_op.cc:8
cudnnActivationDescriptor_t act_desc_
Definition: relu_op.h:77
const OperatorDef & def() const
Return the stored def.
Definition: operator.h:114
Workspace * ws() const
Return the parent workspace.
Definition: operator.h:87
void RunImpl()
Definition: relu_op.cc:28
void RunOnDevice() override
Implement the detailed execution.
Definition: cudnn_relu_op.cc:80
float slope_
Definition: relu_op.h:32
USE_OPERATOR_FUNCTIONS
Definition: relu_op.h:26
dragon::CuDNNReluOp
Definition: relu_op.h:53
float slope_
Definition: relu_op.h:47
dragon::ReluOp
Definition: relu_op.h:21
void CuDNNCreateTensorDesc(cudnnTensorDescriptor_t *desc)
Definition: cudnn_device.h:67
dragon::CuDNNReluGradientOp
Definition: relu_op.h:81
USE_OPERATOR_FUNCTIONS
Definition: relu_op.h:65
ReluGradientOp(const OperatorDef &def, Workspace *ws)
Definition: relu_op.h:38
cudnnTensorDescriptor_t input_desc_
Definition: relu_op.h:76
Definition: operator.h:149
#define OpArg
Definition: operator.h:235
~CuDNNReluOp()
Definition: relu_op.h:67
void RunImpl()
Definition: relu_op.cc:8
USE_OPERATOR_FUNCTIONS
Definition: relu_op.h:41
CuDNNReluGradientOp(const OperatorDef &def, Workspace *ws)
Definition: relu_op.h:83
USE_OPERATOR_FUNCTIONS
Definition: relu_op.h:93
CuDNNReluOp(const OperatorDef &def, Workspace *ws)
Definition: relu_op.h:55
void CuDNNDestroyTensorDesc(cudnnTensorDescriptor_t *desc)
Definition: cudnn_device.h:72
void RunOnDevice() override
Implement the detailed execution.
Definition: relu_op.cc:41
void RunOnDevice() override
Implement the detailed execution.
Definition: cudnn_relu_op.cc:35
void RunOnDevice() override
Implement the detailed execution.
Definition: relu_op.cc:20
void RunImpl()
Definition: cudnn_relu_op.cc:48
dragon::ReluGradientOp
Definition: relu_op.h:36
Definition: common.h:41
~CuDNNReluGradientOp()
Definition: relu_op.h:95
cudnnTensorDescriptor_t input_desc_
Definition: relu_op.h:104