Dragon - C++ API
A Computation Graph Virtual Machine Based Deep Learning Framework
dimension_op.h
Go to the documentation of this file.
1 
13 #ifndef DRAGON_OPERATORS_ARRAY_DIMENSION_OP_H_
14 #define DRAGON_OPERATORS_ARRAY_DIMENSION_OP_H_
15 
#include <climits>

#include "core/operator.h"
17 
18 namespace dragon {
19 
20 /* Base */
21 
22 template <class Context>
23 class DimOpBase : public Operator<Context> {
24  public:
26 
27  void MemorySwitch() override {
28  /* Disable the Memory Activation */
29  }
30 };
31 
32 template <class Context>
33 class DimGradientOpBase : public Operator<Context> {
34  public:
37 
38  void RunOnDevice() override {
39  // Simply copy the dY to dX
40  Y(0)->ReshapeLike(X(0));
41  Y(0)->CopyFrom(X(-1), ctx());
42  }
43 };
44 
// Declares `name##GradientOp<Context>`: a final operator that adds nothing
// to DimGradientOpBase (so its gradient is dY copied into dX, shaped like
// the forward input). The (def, ws) constructor is inherited from the base
// rather than re-declared. The trailing `};` is kept so invocations written
// as `DEFINE_DIMENSION_GRADIENT_OP(X);` behave exactly as before.
#define DEFINE_DIMENSION_GRADIENT_OP(name)                   \
    template <class Context>                                 \
    class name##GradientOp final                             \
        : public DimGradientOpBase<Context> {                \
     public:                                                 \
        using DimGradientOpBase<Context>::DimGradientOpBase; \
    };
55 
56 /* Reshape */
57 
58 template <class Context>
59 class ReshapeOp final : public DimOpBase<Context> {
60  public:
61  ReshapeOp(const OperatorDef& def, Workspace* ws)
62  : DimOpBase<Context>(def, ws),
63  shape_desc_(OpArg<string>("shape_like", "")) {
64  GET_ARGS_WITH_DESC(int64_t, dims);
65  }
67 
68  void RunOnDevice() override;
69 
70  protected:
71  string shape_desc_;
73  DECLARE_ARGS_WITH_DESC(int64_t, dims);
74 };
75 
77 DEFINE_ARGS_WITH_DESC(int64_t, ReshapeOp, dims);
78 
79 /* Flatten */
80 
81 template <class Context>
82 class FlattenOp final : public DimOpBase<Context> {
83  public:
84  FlattenOp(const OperatorDef& def, Workspace* ws)
85  : DimOpBase<Context>(def, ws),
86  axis_(OpArg<int64_t>("axis", 0)),
87  num_axes_(OpArg<int64_t>("num_axes", -1)),
88  keep_axes_(OpArg<int64_t>("keep_axes", INT_MAX)) {}
90 
91  void RunOnDevice() override;
92 
93  protected:
95 };
96 
98 
99 /* ExpandDims */
100 
101 template <class Context>
102 class ExpandDimsOp final : public DimOpBase<Context> {
103  public:
104  ExpandDimsOp(const OperatorDef& def, Workspace* ws)
105  : DimOpBase<Context>(def, ws),
106  axis_(OpArg<int64_t>("axis", 0)) {}
108 
109  void RunOnDevice() override;
110 
111  protected:
112  int64_t axis_;
113 };
114 
115 DEFINE_DIMENSION_GRADIENT_OP(ExpandDims);
116 
117 /* Squeeze */
118 
119 template <class Context>
120 class SqueezeOp final : public DimOpBase<Context> {
121 public:
122  SqueezeOp(const OperatorDef& def, Workspace* ws)
123  : DimOpBase<Context>(def, ws),
124  axis_(OpArg<int64_t>("axis", INT_MAX)) {}
126 
127  void RunOnDevice() override;
128 
129  protected:
130  int64_t axis_;
131 };
132 
134 
135 } // namespace dragon
136 
137 #endif // DRAGON_OPERATORS_ARRAY_RESHAPE_OP_H_
FlattenOp(const OperatorDef &def, Workspace *ws)
Definition: dimension_op.h:84
SIMPLE_CTOR_DTOR(DimGradientOpBase)
Definition: dimension_op.h:33
string shape_desc_
Definition: dimension_op.h:71
Definition: dimension_op.h:23
USE_OPERATOR_FUNCTIONS
Definition: dimension_op.h:66
Definition: workspace.h:20
int64_t num_axes_
Definition: dimension_op.h:94
Definition: dimension_op.h:59
SIMPLE_CTOR_DTOR(DimOpBase)
int64_t axis_
Definition: dimension_op.h:130
Tensor * CopyFrom(const Tensor &other, Context *ctx)
Copy the contents from the given tensor.
Definition: tensor.h:292
USE_OPERATOR_FUNCTIONS
Definition: dimension_op.h:125
void MemorySwitch() override
Coordinate the context of inputs and outputs.
Definition: dimension_op.h:27
const OperatorDef & def() const
Return the stored def.
Definition: operator.h:114
Definition: dimension_op.h:82
Workspace * ws() const
Return the parent workspace.
Definition: operator.h:87
void RunOnDevice() override
Implement the detailed execution.
Definition: expand_dims_op.cc:15
ReshapeOp(const OperatorDef &def, Workspace *ws)
Definition: dimension_op.h:61
Tensor * ReshapeLike(const Tensor &other)
Reshape the dimensions like the given tensor.
Definition: tensor.h:64
vec64_t req_shape_
Definition: dimension_op.h:72
Tensor & X(int i)
Return the specified input tensor.
Definition: operator.cc:46
Definition: dimension_op.h:102
int64_t axis_
Definition: dimension_op.h:94
USE_OPERATOR_FUNCTIONS
Definition: dimension_op.h:36
Definition: operator.h:149
#define OpArg
Definition: operator.h:235
USE_OPERATOR_FUNCTIONS
Definition: dimension_op.h:107
void RunOnDevice() override
Implement the detailed execution.
Definition: flatten_op.cc:15
std::vector< int64_t > vec64_t
Definition: types.h:25
int64_t axis_
Definition: dimension_op.h:112
INT_MAX
Definition: proposal_op.cc:357
Definition: dimension_op.h:120
DECLARE_ARGS_WITH_DESC(int64_t, dims)
void RunOnDevice() override
Implement the detailed execution.
Definition: reshape_op.cc:7
SqueezeOp(const OperatorDef &def, Workspace *ws)
Definition: dimension_op.h:122
void RunOnDevice() override
Implement the detailed execution.
Definition: squeeze_op.cc:15
ExpandDimsOp(const OperatorDef &def, Workspace *ws)
Definition: dimension_op.h:104
vec64_t new_shape_
Definition: dimension_op.h:72
Context * ctx()
Return the internal context.
Definition: operator.h:199
void RunOnDevice() override
Implement the detailed execution.
Definition: dimension_op.h:38
USE_OPERATOR_FUNCTIONS
Definition: dimension_op.h:89
DEFINE_ARGS_WITH_DESC(int64_t, CropOp, starts)
DEFINE_DIMENSION_GRADIENT_OP(Reshape)
#define GET_ARGS_WITH_DESC(type, arg)
Definition: operator.h:412
Tensor * Y(int i)
Return the specified output tensor.
Definition: operator.cc:55
Definition: common.h:41
int64_t keep_axes_
Definition: dimension_op.h:94