Dragon - C++ API
A Computation Graph Virtual Machine Based Deep Learning Framework
multinomial_op.h
Go to the documentation of this file.
1 
13 #ifndef DRAGON_OPERATORS_ARRAY_MULTINOMIAL_OP_H_
14 #define DRAGON_OPERATORS_ARRAY_MULTINOMIAL_OP_H_
15 
16 #include "core/operator.h"
17 
18 namespace dragon {
19 
20 template <class Context>
21 class MultinomialOp final : public Operator<Context> {  // Operator configured by the "eps", "normalize" and "num_samples" arguments.
22  public:
23  MultinomialOp(const OperatorDef& def, Workspace* ws)  // Forwards def/ws to the base and caches the arguments below.
24  : Operator<Context>(def, ws),
25  eps_(OpArg<float>("eps", 0.f)),  // "eps" argument; 0.f is presumably the fallback when unset (confirm in OpArg).
26  normalize_(OpArg<int64_t>("normalize", 0)),  // "normalize" argument; fallback 0.
27  num_samples_(OpArg<int64_t>("num_samples", 1)) {}  // "num_samples" argument; fallback 1.
28  USE_OPERATOR_FUNCTIONS;  // Restored: listed at multinomial_op.h:28 in the index but dropped from the listing.
29 
30  void SoftmaxRun();  // Defined in multinomial_op.cc; presumably runs softmax_op_ (confirm there).
31 
32  void RunOnDevice() override;  // Implement the detailed execution.
33  template <typename T> void RunImpl();  // Per-type implementation, defined in multinomial_op.cc.
34 
35  protected:
36  float eps_;  // Cached "eps" argument.
37  int64_t outer_dim_, axis_;  // Not set by the constructor; assigned elsewhere (not visible here).
38  int64_t normalize_, num_samples_;  // Restored: initialized by the constructor, declared at multinomial_op.h:38.
39  unique_ptr<OperatorBase> softmax_op_;  // Owned helper operator; creation is not shown in this header.
40 };
41 
42 } // namespace dragon
43 
44 #endif // DRAGON_OPERATORS_ARRAY_MULTINOMIAL_OP_H_
dragon::Workspace
Definition: workspace.h:20
int64_t axis_
Definition: multinomial_op.h:37
void RunOnDevice() override
Implement the detailed execution.
Definition: multinomial_op.cc:74
int64_t num_samples_
Definition: multinomial_op.h:38
MultinomialOp(const OperatorDef &def, Workspace *ws)
Definition: multinomial_op.h:23
const OperatorDef & def() const
Return the stored def.
Definition: operator.h:114
Workspace * ws() const
Return the parent workspace.
Definition: operator.h:87
dragon::MultinomialOp
Definition: multinomial_op.h:21
Definition: operator.h:149
#define OpArg
Definition: operator.h:235
float eps_
Definition: multinomial_op.h:36
unique_ptr< OperatorBase > softmax_op_
Definition: multinomial_op.h:39
int64_t normalize_
Definition: multinomial_op.h:38
void SoftmaxRun()
Definition: multinomial_op.cc:9
int64_t outer_dim_
Definition: multinomial_op.h:37
void RunImpl()
Definition: multinomial_op.cc:26
Definition: common.h:41
USE_OPERATOR_FUNCTIONS
Definition: multinomial_op.h:28