Dragon - C++ API
A Computation Graph Virtual Machine Based Deep Learning Framework
matmul_op.h
Go to the documentation of this file.
1 
13 #ifndef DRAGON_OPERATORS_ARITHMETIC_MATMUL_OP_H_
14 #define DRAGON_OPERATORS_ARITHMETIC_MATMUL_OP_H_
15 
16 #include "core/operator.h"
17 
18 namespace dragon {
19 
20 template <class Context>
21 class MatmulOp final : public Operator<Context> {
22  public:
23  MatmulOp(const OperatorDef& def, Workspace* ws)
24  : Operator<Context>(def, ws),
25  transA_(OpArg<bool>("transA", false)),
26  transB_(OpArg<bool>("transB", false)) {}
28 
29  void RunOnDevice() override;
30  template <typename T> void RunImpl();
31 
32  protected:
33  int64_t batch_size_;
34  int64_t transA_, transB_;
35  int64_t M_, K1_, K2_, N_;
36  int64_t M1_, N1_, M2_, N2_;
38 };
39 
40 template <class Context>
41 class MatmulGradientOp final : public Operator<Context> {
42  public:
43  MatmulGradientOp(const OperatorDef& def, Workspace* ws)
44  : Operator<Context>(def, ws),
45  transA_(OpArg<bool>("transA", false)),
46  transB_(OpArg<bool>("transB", false)) {}
48 
49  void RunOnDevice() override;
50  template <typename T> void RunImpl();
51 
52  protected:
53  int64_t batch_size_;
54  int64_t transA_, transB_;
55  int64_t M_, K1_, K2_, N_;
56  int64_t M1_, N1_, M2_, N2_;
58 };
59 
60 } // namespace dragon
61 
62 #endif // DRAGON_OPERATORS_ARITHMETIC_MATMUL_OP_H_
int64_t B_stride_
Definition: matmul_op.h:37
int64_t Y_stride_
Definition: matmul_op.h:37
Definition: workspace.h:20
int64_t batch_size_
Definition: matmul_op.h:53
int64_t K1_
Definition: matmul_op.h:55
const OperatorDef & def() const
Return the stored def.
Definition: operator.h:114
Workspace * ws() const
Return the parent workspace.
Definition: operator.h:87
int64_t N_
Definition: matmul_op.h:35
void RunOnDevice() override
Implement the detailed execution.
Definition: matmul_op.cc:27
int64_t N1_
Definition: matmul_op.h:36
int64_t transA_
Definition: matmul_op.h:54
int64_t K2_
Definition: matmul_op.h:35
int64_t transB_
Definition: matmul_op.h:34
void RunImpl()
Definition: matmul_op.cc:74
int64_t N2_
Definition: matmul_op.h:56
void RunOnDevice() override
Implement the detailed execution.
Definition: matmul_op.cc:139
Definition: operator.h:149
int64_t N1_
Definition: matmul_op.h:56
#define OpArg
Definition: operator.h:235
int64_t N_
Definition: matmul_op.h:55
int64_t A_stride_
Definition: matmul_op.h:37
int64_t M2_
Definition: matmul_op.h:56
int64_t M1_
Definition: matmul_op.h:56
MatmulGradientOp(const OperatorDef &def, Workspace *ws)
Definition: matmul_op.h:43
int64_t N2_
Definition: matmul_op.h:36
int64_t A_stride_
Definition: matmul_op.h:57
int64_t B_stride_
Definition: matmul_op.h:57
int64_t K2_
Definition: matmul_op.h:55
int64_t M_
Definition: matmul_op.h:55
int64_t transA_
Definition: matmul_op.h:34
int64_t batch_size_
Definition: matmul_op.h:33
MatmulOp(const OperatorDef &def, Workspace *ws)
Definition: matmul_op.h:23
USE_OPERATOR_FUNCTIONS
Definition: matmul_op.h:47
int64_t Y_stride_
Definition: matmul_op.h:57
int64_t M_
Definition: matmul_op.h:35
int64_t M2_
Definition: matmul_op.h:36
int64_t M1_
Definition: matmul_op.h:36
void RunImpl()
Definition: matmul_op.cc:7
MatmulOp
Definition: matmul_op.h:21
USE_OPERATOR_FUNCTIONS
Definition: matmul_op.h:27
MatmulGradientOp
Definition: matmul_op.h:41
Definition: common.h:41
int64_t K1_
Definition: matmul_op.h:35
int64_t transB_
Definition: matmul_op.h:54