Dragon - C++ API
A Computation Graph Virtual Machine Based Deep Learning Framework
cudnn_device.h
Go to the documentation of this file.
1 
13 #ifndef DRAGON_UTILS_CUDNN_DEVICE_H_
14 #define DRAGON_UTILS_CUDNN_DEVICE_H_
15 
16 #ifdef WITH_CUDNN
17 
18 #include <stdint.h>
19 #include <vector>
20 #include <cudnn.h>
21 
22 #include "core/types.h"
23 
24 namespace dragon {
25 
26 class Tensor;
27 
// Packs a (major, minor, patch) triple into the same integer scheme that the
// cuDNN headers use for CUDNN_VERSION, so the two can be compared directly.
// cuDNN 9 changed CUDNN_VERSION from major*1000 to major*10000 (the
// minor*100 + patch part is unchanged); keying on CUDNN_MAJOR keeps the
// comparisons below correct on both sides of that change.
// Every argument is parenthesized so that an expression argument such as
// `5 + 1` is treated as one operand.
#if defined(CUDNN_MAJOR) && CUDNN_MAJOR >= 9
#define DRAGON_CUDNN_VERSION_ENCODE(major, minor, patch) \
  ((major) * 10000 + (minor) * 100 + (patch))
#else
#define DRAGON_CUDNN_VERSION_ENCODE(major, minor, patch) \
  ((major) * 1000 + (minor) * 100 + (patch))
#endif

// True when the cuDNN being compiled against is at least major.minor.patch.
#define CUDNN_VERSION_MIN(major, minor, patch) \
  (CUDNN_VERSION >= DRAGON_CUDNN_VERSION_ENCODE(major, minor, patch))

// True when the cuDNN being compiled against is strictly older than
// major.minor.patch.
#define CUDNN_VERSION_MAX(major, minor, patch) \
  (CUDNN_VERSION < DRAGON_CUDNN_VERSION_ENCODE(major, minor, patch))
33 
// Evaluates a cuDNN call exactly once and fails a CHECK_EQ (presumably the
// glog-style fatal check used elsewhere in the project -- confirm), appending
// cudnnGetErrorString(), unless the call returned CUDNN_STATUS_SUCCESS.
// The local has a deliberately unusual name: with a plain `status`, a call
// such as CUDNN_CHECK(f(status)) would bind `status` to the macro's own,
// still-uninitialized local instead of the caller's variable.
#define CUDNN_CHECK(condition) \
  do { \
    cudnnStatus_t cudnn_check_status = (condition); \
    CHECK_EQ(cudnn_check_status, CUDNN_STATUS_SUCCESS) \
        << "\n" << cudnnGetErrorString(cudnn_check_status); \
  } while (0)
40 
41 template <typename T> class CuDNNType;
42 
43 template<> class CuDNNType<float> {
44  public:
45  static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
46  static float oneval, zeroval;
47  static const void *one, *zero;
48  typedef float BNParamType;
49 };
50 
51 template<> class CuDNNType<double> {
52  public:
53  static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
54  static double oneval, zeroval;
55  static const void *one, *zero;
56  typedef double BNParamType;
57 };
58 
59 template<> class CuDNNType<float16> {
60  public:
61  static const cudnnDataType_t type = CUDNN_DATA_HALF;
62  static float oneval, zeroval;
63  static const void *one, *zero;
64  typedef float BNParamType;
65 };
66 
68  cudnnTensorDescriptor_t* desc) {
69  CUDNN_CHECK(cudnnCreateTensorDescriptor(desc));
70 }
71 
73  cudnnTensorDescriptor_t* desc) {
74  CUDNN_CHECK(cudnnDestroyTensorDescriptor(*desc));
75 }
76 
77 template <typename T>
79  cudnnTensorDescriptor_t* desc,
80  Tensor* tensor);
81 
// NOTE(review): listing line 83 (this declaration's `void <name>(` line) is
// missing from this extraction, so the declared name is not visible here.
// It is one of the (desc, data_format, Tensor*) overloads named in the index
// below: CuDNNSetTensor4dDesc / 5dDesc / 3dDesc -- TODO restore the line.
82 template <typename T>
84  cudnnTensorDescriptor_t* desc,
85  const string& data_format,
86  Tensor* tensor);
87 
// NOTE(review): listing line 89 (this declaration's `void <name>(` line) is
// missing from this extraction, so the declared name is not visible here.
// It is one of the (desc, data_format, Tensor*) overloads named in the index
// below: CuDNNSetTensor4dDesc / 5dDesc / 3dDesc -- TODO restore the line.
88 template <typename T>
90  cudnnTensorDescriptor_t* desc,
91  const string& data_format,
92  Tensor* tensor);
93 
// NOTE(review): listing line 95 (this declaration's `void <name>(` line) is
// missing from this extraction, so the declared name is not visible here.
// It is one of the (desc, data_format, Tensor*) overloads named in the index
// below: CuDNNSetTensor4dDesc / 5dDesc / 3dDesc -- TODO restore the line.
94 template <typename T>
96  cudnnTensorDescriptor_t* desc,
97  const string& data_format,
98  Tensor* tensor);
99 
100 template <typename T>
101 void CuDNNSetTensorDesc(
102  cudnnTensorDescriptor_t* desc,
103  const vec64_t& dims);
104 
// NOTE(review): listing line 106 (this declaration's `void <name>(` line) is
// missing from this extraction, so the declared name is not visible here.
// Presumably the vec64_t-dims overload of CuDNNSetTensor4dDesc / 5dDesc /
// 3dDesc, mirroring the Tensor* overloads above -- TODO confirm and restore.
105 template <typename T>
107  cudnnTensorDescriptor_t* desc,
108  const string& data_format,
109  const vec64_t& dims);
110 
111 template <typename T>
113  cudnnTensorDescriptor_t* desc,
114  const string& data_format,
115  const vec64_t& dims,
116  const int64_t group);
117 
// NOTE(review): listing line 119 (this declaration's `void <name>(` line) is
// missing from this extraction, so the declared name is not visible here.
// Presumably the vec64_t-dims overload of CuDNNSetTensor4dDesc / 5dDesc /
// 3dDesc, mirroring the Tensor* overloads above -- TODO confirm and restore.
118 template <typename T>
120  cudnnTensorDescriptor_t* desc,
121  const string& data_format,
122  const vec64_t& dims);
123 
// NOTE(review): listing line 125 (this declaration's `void <name>(` line) is
// missing from this extraction, so the declared name is not visible here.
// Presumably the vec64_t-dims overload of CuDNNSetTensor4dDesc / 5dDesc /
// 3dDesc, mirroring the Tensor* overloads above -- TODO confirm and restore.
124 template <typename T>
126  cudnnTensorDescriptor_t* desc,
127  const string& data_format,
128  const vec64_t& dims);
129 
130 template <typename T>
131 void CuDNNSetTensorDesc(
132  cudnnTensorDescriptor_t* desc,
133  const vec64_t& dims,
134  const vec64_t& strides);
135 
136 } // namespace dragon
137 
138 #endif // WITH_CUDNN
139 
140 #endif // DRAGON_UTILS_CUDNN_DEVICE_H_
static float zeroval
Definition: cudnn_device.h:46
#define CUDNN_CHECK(condition)
Definition: cudnn_device.h:34
void CuDNNSetTensor4dDesc(cudnnTensorDescriptor_t *desc, const string &data_format, Tensor *tensor)
Definition: cudnn_device.cc:196
void CuDNNSetTensor3dDesc(cudnnTensorDescriptor_t *desc, const string &data_format, Tensor *tensor)
Definition: cudnn_device.cc:220
float BNParamType
Definition: cudnn_device.h:64
void CuDNNCreateTensorDesc(cudnnTensorDescriptor_t *desc)
Definition: cudnn_device.h:67
void CuDNNSetTensorDesc(cudnnTensorDescriptor_t *desc, Tensor *tensor)
Definition: cudnn_device.cc:182
void CuDNNSetTensor5dDesc(cudnnTensorDescriptor_t *desc, const string &data_format, Tensor *tensor)
Definition: cudnn_device.cc:208
static const void * zero
Definition: cudnn_device.h:55
std::vector< int64_t > vec64_t
Definition: types.h:25
static const void * zero
Definition: cudnn_device.h:47
static const void * zero
Definition: cudnn_device.h:63
float BNParamType
Definition: cudnn_device.h:48
static float zeroval
Definition: cudnn_device.h:62
void CuDNNDestroyTensorDesc(cudnnTensorDescriptor_t *desc)
Definition: cudnn_device.h:72
static double zeroval
Definition: cudnn_device.h:54
double BNParamType
Definition: cudnn_device.h:56
template <typename T> class CuDNNType
Definition: cudnn_device.h:41
Definition: common.h:41
void CuDNNSetTensor4dDescWithGroup(cudnnTensorDescriptor_t *desc, const string &data_format, const vec64_t &dims, const int64_t group)
Definition: cudnn_device.cc:83