Dragon - C++ API
A Computation Graph Virtual Machine Based Deep Learning Framework
ctc_loss_op.h
Go to the documentation of this file.
1 
13 #ifndef DRAGON_OPERATORS_LOSS_CTC_LOSS_OP_H_
14 #define DRAGON_OPERATORS_LOSS_CTC_LOSS_OP_H_
15 
16 #include "core/operator.h"
17 
18 namespace dragon {
19 
20 template <class Context>
21 class CTCLossOp final : public Operator<Context> {
22  public:
23  CTCLossOp(const OperatorDef& def, Workspace* ws)
24  : Operator<Context>(def, ws) {
25  LOG(FATAL) << "CTCLoss requires CuDNN support.";
26  }
28 
29  void RunOnDevice() override {}
30 };
31 
32 template <class Context>
33 class CTCLossGradientOp final : public Operator<Context> {
34  public:
35  CTCLossGradientOp(const OperatorDef& def, Workspace* ws)
36  : Operator<Context>(def, ws) {}
38 
39  void RunOnDevice() override;
40  template <typename T> void RunImpl();
41 };
42 
43 #ifdef WITH_CUDNN
44 
45 #if CUDNN_VERSION_MIN(7, 0, 0)
46 
47 template <class Context>
48 class CuDNNCTCLossOp final : public Operator<Context> {
49  public:
50  CuDNNCTCLossOp(const OperatorDef& def, Workspace* ws)
51  : Operator<Context>(def, ws),
52  blank_first_(OpArg<bool>("blank_first", true)),
53  padding_mask_(OpArg<int64_t>("padding_mask", -1)) {
54  CuDNNCreateTensorDesc(&prob_desc_);
55  CuDNNCreateTensorDesc(&grad_desc_);
56  ctc_algo_ = CUDNN_CTC_LOSS_ALGO_DETERMINISTIC;
57  CUDNN_CHECK(cudnnCreateCTCLossDescriptor(&ctc_desc_));
58  }
60 
61  ~CuDNNCTCLossOp() {
62  CuDNNDestroyTensorDesc(&prob_desc_);
63  CuDNNDestroyTensorDesc(&grad_desc_);
64  CUDNN_CHECK(cudnnDestroyCTCLossDescriptor(ctc_desc_));
65  }
66 
67  void Reshape();
68 
69  void RunOnDevice() override;
70  template <typename T> void RunImpl();
71 
72  protected:
73  bool blank_first_;
74  int64_t padding_mask_;
75  size_t workspace_size_;
76  cudnnCTCLossAlgo_t ctc_algo_;
77  cudnnCTCLossDescriptor_t ctc_desc_;
78  cudnnTensorDescriptor_t prob_desc_, grad_desc_;
79  vec32_t packed_labels_, label_lengths_, input_lengths_;
80 };
81 
82 #endif
83 
84 #endif // WITH_CUDNN
85 
86 } // namespace dragon
87 
88 #endif // DRAGON_OPERATORS_LOSS_CTC_LOSS_OP_H_
dragon::CTCLossOp
Definition: ctc_loss_op.h:21
Definition: logging.h:21
Definition: workspace.h:20
#define CUDNN_CHECK(condition)
Definition: cudnn_device.h:34
CTCLossGradientOp(const OperatorDef &def, Workspace *ws)
Definition: ctc_loss_op.h:35
const OperatorDef & def() const
Return the stored def.
Definition: operator.h:114
Workspace * ws() const
Return the parent workspace.
Definition: operator.h:87
void RunImpl()
Definition: ctc_loss_op.cc:8
void RunOnDevice() override
Implement the detailed execution.
Definition: ctc_loss_op.h:29
void CuDNNCreateTensorDesc(cudnnTensorDescriptor_t *desc)
Definition: cudnn_device.h:67
USE_OPERATOR_FUNCTIONS
Definition: ctc_loss_op.h:27
Definition: operator.h:149
#define OpArg
Definition: operator.h:235
#define USE_OPERATOR_FUNCTIONS
Definition: operator.h:261
dragon::CTCLossGradientOp
Definition: ctc_loss_op.h:33
USE_OPERATOR_FUNCTIONS
Definition: ctc_loss_op.h:37
void CuDNNDestroyTensorDesc(cudnnTensorDescriptor_t *desc)
Definition: cudnn_device.h:72
CTCLossOp(const OperatorDef &def, Workspace *ws)
Definition: ctc_loss_op.h:23
void RunOnDevice() override
Implement the detailed execution.
Definition: ctc_loss_op.cc:25
#define LOG(severity)
Definition: logging.h:54
std::vector< int > vec32_t
Definition: types.h:24
Definition: common.h:41