GumbelSoftmax
class dragon.vm.torch.nn.GumbelSoftmax(tau=1, dim=None, inplace=False)
Apply the Gumbel softmax function [Jang et al., 2016].
The GumbelSoftmax function is defined as:
\[\text{GumbelSoftmax}(x)_{i} = \frac{\exp((\log(\pi_{i}) + g_{i}) / \tau)} {\sum_{j} \exp((\log(\pi_{j}) + g_{j}) / \tau)} \\ \quad \\ \text{where}\quad g_{i} \sim \text{Gumbel}(0, 1) \]
Examples:
from dragon.vm import torch
m = torch.nn.GumbelSoftmax(tau=0.5, dim=1)
x = torch.randn(2, 3)
y = m(x)  # same shape as x; each row along dim=1 sums to 1
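To make the formula concrete, the following is a minimal pure-Python sketch of the same computation on a 1-D list. It is not the Dragon implementation: the function name `gumbel_softmax`, the `rng` argument, and the assumption that the input values play the role of \(\log(\pi_{i})\) are all illustrative choices.

```python
import math
import random


def gumbel_softmax(logits, tau=1.0, rng=random):
    """Illustrative Gumbel softmax over a 1-D list (hypothetical helper)."""
    if tau <= 0:
        raise ValueError("tau must be positive")
    # g_i ~ Gumbel(0, 1) via inverse transform: g = -log(-log(u)), u ~ Uniform(0, 1).
    noise = []
    for _ in logits:
        u = max(rng.random(), 1e-12)  # keep u away from 0 so log(u) is finite
        noise.append(-math.log(-math.log(u)))
    # Perturb each logit with its own noise sample, then scale by the temperature.
    scaled = [(x + g) / tau for x, g in zip(logits, noise)]
    # Softmax with the maximum subtracted for numerical stability.
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(0)  # seeded so the run is repeatable
y = gumbel_softmax([0.5, 1.0, -0.3], tau=0.5, rng=rng)
print(y)
```

Each call draws fresh noise, so repeated calls on the same input return different probability vectors. Every output is non-negative and sums to 1.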
__init__
GumbelSoftmax.__init__(tau=1, dim=None, inplace=False)
Create a GumbelSoftmax module.
Parameters:
- tau (Union[number, dragon.vm.torch.Tensor], default=1) – The temperature to use.
- dim (int, required) – The dimension to reduce, i.e. the axis along which the outputs are normalized to sum to 1.
- inplace (bool, optional, default=False) – Whether to do the operation in-place.
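The effect of the temperature can be illustrated with a short pure-Python sketch. It is an assumption-laden illustration, not Dragon code: the helper `softmax_with_temperature` is hypothetical, and the Gumbel noise is drawn once and held fixed so that only `tau` varies between runs.

```python
import math
import random


def softmax_with_temperature(scores, tau):
    """Softmax of scores / tau, computed stably (hypothetical helper)."""
    scaled = [s / tau for s in scores]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(0)
logits = [0.5, 1.0, -0.3]
# Add one Gumbel(0, 1) sample per logit; reuse the same noise for every tau.
perturbed = [x - math.log(-math.log(max(rng.random(), 1e-12))) for x in logits]

results = []
peaks = []
for tau in (5.0, 1.0, 0.1):
    y = softmax_with_temperature(perturbed, tau)
    results.append(y)
    peaks.append(max(y))
    print(f"tau={tau}: {[round(v, 3) for v in y]}")
```

For fixed noise, the largest output probability does not decrease as `tau` shrinks: a small temperature pushes the output toward a one-hot vector, while a large temperature flattens it toward uniform.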