Dragon - C++ API
A Computation Graph Virtual Machine Based Deep Learning Framework
smooth_l1_loss_op.h
Go to the documentation of this file.
1 
13 #ifndef DRAGON_OPERATORS_LOSS_SMOOTH_L1_LOSS_OP_H_
14 #define DRAGON_OPERATORS_LOSS_SMOOTH_L1_LOSS_OP_H_
15 
16 #include "core/operator.h"
17 
18 namespace dragon {
19 
20 template <class Context>
21 class SmoothL1LossOp final
22  : public Operator<Context> {
23  public:
24  SmoothL1LossOp(const OperatorDef& def, Workspace* ws)
25  : Operator<Context>(def, ws),
26  beta_(OpArg<float>("beta", 1.f)),
27  reduction_(OpArg<string>(
28  "reduction", "BATCH_SIZE")) {}
30 
31  void RunOnDevice() override;
32  template <typename T> void RunImpl();
33 
34  protected:
35  float beta_;
36  string reduction_;
37 };
38 
39 template <class Context>
41  : public Operator<Context> {
42  public:
43  SmoothL1LossGradientOp(const OperatorDef& def, Workspace* ws)
44  : Operator<Context>(def, ws),
45  beta_(OpArg<float>(
46  "beta", 1.f)),
47  reduction_(OpArg<string>(
48  "reduction", "BATCH_SIZE")) {}
50 
51  void RunOnDevice() override;
52  template <typename T> void RunImpl();
53 
54  protected:
55  float beta_;
56  string reduction_;
57 };
58 
59 } // namespace dragon
60 
61 #endif // DRAGON_OPERATORS_LOSS_SMOOTH_L1_LOSS_OP_H_
void RunOnDevice() override
Implement the detailed execution.
Definition: smooth_l1_loss_op.cc:108
Definition: workspace.h:20
Definition: smooth_l1_loss_op.h:40
const OperatorDef & def() const
Return the stored def.
Definition: operator.h:114
Workspace * ws() const
Return the parent workspace.
Definition: operator.h:87
float beta_
Definition: smooth_l1_loss_op.h:55
void RunOnDevice() override
Implement the detailed execution.
Definition: smooth_l1_loss_op.cc:52
USE_OPERATOR_FUNCTIONS
Definition: smooth_l1_loss_op.h:29
Definition: operator.h:149
#define OpArg
Definition: operator.h:235
void RunImpl()
Definition: smooth_l1_loss_op.cc:62
SmoothL1LossOp(const OperatorDef &def, Workspace *ws)
Definition: smooth_l1_loss_op.h:24
string reduction_
Definition: smooth_l1_loss_op.h:36
SmoothL1LossGradientOp(const OperatorDef &def, Workspace *ws)
Definition: smooth_l1_loss_op.h:43
float beta_
Definition: smooth_l1_loss_op.h:35
void RunImpl()
Definition: smooth_l1_loss_op.cc:9
USE_OPERATOR_FUNCTIONS
Definition: smooth_l1_loss_op.h:49
Definition: common.h:41
Definition: smooth_l1_loss_op.h:21
string reduction_
Definition: smooth_l1_loss_op.h:56