Dragon - C++ API
A Computation Graph Virtual Machine Based Deep Learning Framework
concat_op.h
Go to the documentation of this file.
1 
13 #ifndef DRAGON_OPERATORS_ARRAY_CONCAT_OP_H_
14 #define DRAGON_OPERATORS_ARRAY_CONCAT_OP_H_
15 
16 #include "core/operator.h"
17 
18 namespace dragon {
19 
20 template <class Context>
21 class ConcatOp : public Operator<Context> {
22  public:
23  ConcatOp(const OperatorDef& def, Workspace* ws)
24  : Operator<Context>(def, ws),
25  axis_(OpArg<int64_t>("axis", 0)) {}
27 
28  void RunOnDevice() override;
29  template <typename T> void RunImpl();
30 
31  protected:
32  int64_t axis_, cat_dim_;
34 };
35 
36 template <class Context>
37 class ConcatGradientOp : public Operator<Context> {
38  public:
39  ConcatGradientOp(const OperatorDef& def, Workspace* ws)
40  : Operator<Context>(def, ws),
41  axis_(OpArg<int64_t>("axis", 0)) {}
43 
44  void RunOnDevice() override;
45  template <typename T> void RunImpl();
46 
47  protected:
48  int64_t axis_, cat_dim_;
50 };
51 
52 } // namespace dragon
53 
54 #endif // DRAGON_OPERATORS_ARRAY_CONCAT_OP_H_
ConcatGradientOp(const OperatorDef &def, Workspace *ws)
Definition: concat_op.h:39
USE_OPERATOR_FUNCTIONS
Definition: concat_op.h:42
class ConcatGradientOp
Definition: concat_op.h:37
Definition: workspace.h:20
void RunOnDevice() override
Implement the detailed execution.
Definition: concat_op.cc:36
const OperatorDef & def() const
Return the stored def.
Definition: operator.h:114
Workspace * ws() const
Return the parent workspace.
Definition: operator.h:87
int64_t outer_dim_
Definition: concat_op.h:49
int64_t cat_dim_
Definition: concat_op.h:48
class ConcatOp
Definition: concat_op.h:21
USE_OPERATOR_FUNCTIONS
Definition: concat_op.h:26
int64_t axis_
Definition: concat_op.h:48
Definition: operator.h:149
#define OpArg
Definition: operator.h:235
int64_t inner_dim_
Definition: concat_op.h:49
int64_t cat_dim_
Definition: concat_op.h:32
int64_t outer_dim_
Definition: concat_op.h:33
ConcatOp(const OperatorDef &def, Workspace *ws)
Definition: concat_op.h:23
void RunOnDevice() override
Implement the detailed execution.
Definition: concat_op.cc:89
void RunImpl()
Definition: concat_op.cc:16
void RunImpl()
Definition: concat_op.cc:66
Definition: common.h:41
int64_t axis_
Definition: concat_op.h:32
int64_t inner_dim_
Definition: concat_op.h:33