Dragon - C++ API
A Computation Graph Virtual Machine Based Deep Learning Framework
context.h
Go to the documentation of this file.
1 
13 #ifndef DRAGON_CORE_CONTEXT_H_
14 #define DRAGON_CORE_CONTEXT_H_
15 
16 #include "core/common.h"
17 
18 namespace dragon {
19 
20 class CPUContext {
21  public:
23  CPUContext(): random_seed_(3) {}
24 
26  CPUContext(unsigned int random_seed)
27  : random_seed_(random_seed) {}
28 
30  CPUContext(const DeviceOption& option)
31  : random_seed_(option.has_random_seed() ?
32  option.random_seed() : DEFAULT_RNG_SEED) {}
33 
35  virtual ~CPUContext() {}
36 
38  void SwitchToDevice() {}
39 
41  void SwitchToDevice(const int stream_id) {}
42 
45 
47  static void* New(size_t nbytes) {
48  void* data = malloc(nbytes);
49  CHECK(data) << "\nMalloc mem: "
50  << nbytes << " bytes failed.";
51  return data;
52  }
53 
55  static void Memset(
56  size_t nbytes,
57  void* ptr) {
58  memset(ptr, 0, nbytes);
59  }
60 
63  size_t nbytes,
64  void* ptr) {
65  memset(ptr, 0, nbytes);
66  }
67 
69  template<class DstContext, class SrcContext>
70  static void Memcpy(
71  size_t nbytes,
72  void* dst,
73  const void* src) {
74  memcpy(dst, src, nbytes);
75  }
76 
78  template<class DstContext, class SrcContext>
80  size_t nbytes,
81  void* dst,
82  const void* src) {
83  memcpy(dst, src, nbytes);
84  }
85 
87  template<typename T, class DstContext, class SrcContext>
88  void Copy(
89  int n,
90  T* dst,
91  const T* src) {
92  if (dst == src) return;
93  if (std::is_fundamental<T>::value)
94  Memcpy<DstContext, SrcContext>(
95  n * sizeof(T), (void*)dst, (const void*)src);
96  else for (int i = 0; i < n; i++) dst[i] = src[i];
97  }
98 
100  static void Delete(void* data) { free(data); }
101 
103  int device_id() const { return 0; }
104 
106  int stream_id() const { return 0; }
107 
110 
112  std::mt19937* rand_generator() {
113  if (!rand_generator_.get())
114  rand_generator_.reset(new std::mt19937(random_seed_));
115  return rand_generator_.get();
116  }
117 
118  private:
120  unsigned int random_seed_;
121 
123  unique_ptr<std::mt19937> rand_generator_;
124 };
125 
126 #define CPU_FP16_NOT_SUPPORTED \
127  LOG(FATAL) << "FP16 is unsupported for CPUContext.";
128 
129 } // namepsace dragon
130 
131 #endif // DRAGON_CORE_CONTEXT_H_
#define DEFAULT_RNG_SEED
Definition: common.h:75
static void Delete(void *data)
Free the memory.
Definition: context.h:100
int device_id() const
Return the device id.
Definition: context.h:103
void SwitchToDevice(const int stream_id)
Switch to the device with the given stream.
Definition: context.h:41
CPUContext(const DeviceOption &option)
Constructor with the specified device option.
Definition: context.h:30
void set_stream_id(int stream_id)
Set the stream id.
Definition: context.h:109
int stream_id() const
Return the stream id.
Definition: context.h:106
CPUContext
Definition: context.h:20
virtual ~CPUContext()
Destructor.
Definition: context.h:35
void SwitchToDevice()
Switch to the device of this context.
Definition: context.h:38
static void * New(size_t nbytes)
Malloc the memory.
Definition: context.h:47
void FinishDeviceCompution()
Synchronize the dispatched operations.
Definition: context.h:44
static void Memcpy(size_t nbytes, void *dst, const void *src)
Copy the memory.
Definition: context.h:70
void MemcpyAsync(size_t nbytes, void *dst, const void *src)
Copy the memory asynchronously.
Definition: context.h:79
static void Memset(size_t nbytes, void *ptr)
Zero-Reset the memory.
Definition: context.h:55
CPUContext(unsigned int random_seed)
Constructor with the specified random seed.
Definition: context.h:26
std::mt19937 * rand_generator()
Return the internal random generator.
Definition: context.h:112
void Copy(int n, T *dst, const T *src)
Copy the memory with the given type.
Definition: context.h:88
CPUContext()
Default Constructor.
Definition: context.h:23
void MemsetAsync(size_t nbytes, void *ptr)
Zero-Reset the memory asynchronously.
Definition: context.h:62
#define CHECK(condition)
Definition: logging.h:45
Definition: common.h:41