Dragon - C++ API
A Computation Graph Virtual Machine Based Deep Learning Framework
pad_op.h
Go to the documentation of this file.
1 
13 #ifndef DRAGON_OPERATORS_ARRAY_PAD_OP_H_
14 #define DRAGON_OPERATORS_ARRAY_PAD_OP_H_
15 
16 #include "core/operator.h"
17 
18 namespace dragon {
19 
// PadOp: declaration of the pad operator, parameterized on the device Context.
// Only the constructor is defined inline; RunOnDevice and the *RunImpl
// templates are defined out of line (the index on this page points to pad_op.cc).
//
// Listing note: header line 37 (USE_OPERATOR_FUNCTIONS) and lines 48-49
// (vec64_t pad_l_, pad_r_; Tensor pads_, X_dims_, Y_dims_, X_strides_) are
// not rendered in this listing; they appear only in the member index below.
20 template <class Context>
21 class PadOp final : public Operator<Context> {
22  public:
   // Builds the operator from its definition and parent workspace.
   // Reads the arguments "pad_l" and "pad_r" (int64_t lists), "mode"
   // (string, "CONSTANT" passed as second argument) and "value" (float, 0.f
   // passed as second argument) — presumably the second argument is the
   // default used when the argument is absent; confirm in operator.h.
   // NOTE(review): C++ initializes members in declaration order (value_,
   // mode_, then pad_l_/pad_r_ per the index), not in the order written
   // here. The initializers are independent, so the result is the same, but
   // compilers may emit a -Wreorder warning.
23  PadOp(const OperatorDef& def, Workspace* ws)
24  : Operator<Context>(def, ws),
25  pad_l_(OpArgs<int64_t>("pad_l")),
26  pad_r_(OpArgs<int64_t>("pad_r")),
27  mode_(OpArg<string>("mode", "CONSTANT")),
28  value_(OpArg<float>("value", 0.f)) {
   // When no "pad_r" values were supplied, reuse the "pad_l" values.
29  if (pad_r_.empty()) {
30  pad_r_ = pad_l_;
31  } else {
   // Otherwise both lists must have the same number of entries; the
   // streamed text is the message attached to the CHECK_EQ (logging.h).
32  CHECK_EQ(pad_l_.size(), pad_r_.size())
33  << "\nThe pad_l and pad_r "
34  << "should have the same length.";
35  }
36  }
38 
   // Operator entry point; overrides the base-class virtual.
39  void RunOnDevice() override;
   // Per-element-type implementations. The Const/Reflect/Edge variants
   // presumably correspond to values of mode_ — confirm in pad_op.cc.
40  template <typename T> void RunImpl();
41  template <typename T> void ConstRunImpl();
42  template <typename T> void ReflectRunImpl();
43  template <typename T> void EdgeRunImpl();
44 
45  protected:
   // Value of the "value" argument (0.f passed as fallback in the constructor).
46  float value_;
   // Value of the "mode" argument ("CONSTANT" passed as fallback).
47  string mode_;
50 };
51 
// PadGradientOp: declaration of the gradient counterpart of the pad operator,
// parameterized on the device Context. It reads the same "pad_l", "pad_r"
// and "mode" arguments as PadOp but has no "value" argument.
//
// Listing note: header line 68 (USE_OPERATOR_FUNCTIONS) and lines 78-79
// (vec64_t pad_l_, pad_r_; Tensor pads_, X_dims_, Y_strides_) are not
// rendered in this listing; they appear only in the member index below.
52 template <class Context>
53 class PadGradientOp final : public Operator<Context> {
54  public:
   // Builds the operator from its definition and parent workspace.
   // Reads "pad_l" and "pad_r" (int64_t lists) and "mode" (string,
   // "CONSTANT" passed as second argument — presumably the default; confirm
   // in operator.h).
   // NOTE(review): members are initialized in declaration order (mode_
   // before pad_l_/pad_r_ per the index), not in the order written here;
   // the result is the same, but compilers may emit a -Wreorder warning.
55  PadGradientOp(const OperatorDef& def, Workspace* ws)
56  : Operator<Context>(def, ws),
57  pad_l_(OpArgs<int64_t>("pad_l")),
58  pad_r_(OpArgs<int64_t>("pad_r")),
59  mode_(OpArg<string>("mode", "CONSTANT")) {
   // When no "pad_r" values were supplied, reuse the "pad_l" values.
60  if (pad_r_.empty()) {
61  pad_r_ = pad_l_;
62  } else {
   // Otherwise both lists must have the same number of entries; the
   // streamed text is the message attached to the CHECK_EQ (logging.h).
63  CHECK_EQ(pad_l_.size(), pad_r_.size())
64  << "\nThe pad_l and pad_r "
65  << "should have the same length.";
66  }
67  }
69 
   // Operator entry point; overrides the base-class virtual.
70  void RunOnDevice() override;
   // Per-element-type implementations. The Const/Reflect/Edge variants
   // presumably correspond to values of mode_ — confirm in pad_op.cc.
71  template <typename T> void RunImpl();
72  template <typename T> void ConstRunImpl();
73  template <typename T> void ReflectRunImpl();
74  template <typename T> void EdgeRunImpl();
75 
76  protected:
   // Value of the "mode" argument ("CONSTANT" passed as fallback).
77  string mode_;
80 };
81 
82 } // namespace dragon
83 
84 #endif // DRAGON_OPERATORS_ARRAY_PAD_OP_H_
Tensor Y_strides_
Definition: pad_op.h:79
Definition: workspace.h:20
void RunImpl()
Definition: pad_op.cc:148
void ConstRunImpl()
Definition: pad_op.cc:122
const OperatorDef & def() const
Return the stored def.
Definition: operator.h:114
Workspace * ws() const
Return the parent workspace.
Definition: operator.h:87
void ReflectRunImpl()
Definition: pad_op.cc:138
Tensor X_dims_
Definition: pad_op.h:49
dragon::PadGradientOp
Definition: pad_op.h:53
vec64_t pad_r_
Definition: pad_op.h:78
Definition: tensor.h:21
Tensor X_dims_
Definition: pad_op.h:79
USE_OPERATOR_FUNCTIONS
Definition: pad_op.h:68
vec64_t pad_l_
Definition: pad_op.h:78
#define OpArgs
Definition: operator.h:236
Definition: operator.h:149
#define OpArg
Definition: operator.h:235
void ConstRunImpl()
Definition: pad_op.cc:16
std::vector< int64_t > vec64_t
Definition: types.h:25
void EdgeRunImpl()
Definition: pad_op.cc:48
vec64_t pad_r_
Definition: pad_op.h:48
USE_OPERATOR_FUNCTIONS
Definition: pad_op.h:37
void ReflectRunImpl()
Definition: pad_op.cc:32
void RunOnDevice() override
Implement the detailed execution.
Definition: pad_op.cc:175
Tensor Y_dims_
Definition: pad_op.h:49
void RunImpl()
Definition: pad_op.cc:64
float value_
Definition: pad_op.h:46
Tensor pads_
Definition: pad_op.h:79
vec64_t pad_l_
Definition: pad_op.h:48
dragon::PadOp
Definition: pad_op.h:21
void RunOnDevice() override
Implement the detailed execution.
Definition: pad_op.cc:91
Tensor X_strides_
Definition: pad_op.h:49
string mode_
Definition: pad_op.h:77
PadOp(const OperatorDef &def, Workspace *ws)
Definition: pad_op.h:23
void EdgeRunImpl()
Definition: pad_op.cc:143
Definition: common.h:41
Tensor pads_
Definition: pad_op.h:49
PadGradientOp(const OperatorDef &def, Workspace *ws)
Definition: pad_op.h:55
#define CHECK_EQ(val1, val2)
Definition: logging.h:48
string mode_
Definition: pad_op.h:47