GumbelSoftmax

class dragon.vm.torch.nn.GumbelSoftmax(
  tau=1,
  dim=None,
  inplace=False
)

Apply the Gumbel-Softmax function [Jang et al., 2016].

The GumbelSoftmax function is defined as:

\[\text{GumbelSoftmax}(x)_{i} = \frac{\exp((\log(\pi_{i}) + g_{i}) / \tau)} {\sum_{j} \exp((\log(\pi_{j}) + g_{j}) / \tau)} \\ \quad \\ \text{where}\quad g_{i} \sim \text{Gumbel}(0, 1) \]
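
The formula above can be sketched in plain NumPy to make each term concrete. This is an illustrative sketch, not Dragon's implementation: the function name `gumbel_softmax`, its `axis` and `rng` arguments, and the assumption that the input holds the logits \(\log(\pi)\) are all hypothetical choices made for this example.

```python
import numpy as np

def gumbel_softmax(logits, tau=1.0, axis=-1, rng=None):
    """NumPy sketch of the Gumbel-Softmax formula (hypothetical helper)."""
    rng = np.random.default_rng() if rng is None else rng
    logits = np.asarray(logits, dtype=np.float64)
    # Sample g ~ Gumbel(0, 1) by inverse transform: g = -log(-log(u)).
    # The lower bound keeps u away from 0 so the logs stay finite.
    u = rng.uniform(low=1e-10, high=1.0, size=logits.shape)
    g = -np.log(-np.log(u))
    # Perturb the logits (assumed to be log(pi)) and divide by the temperature.
    z = (logits + g) / tau
    # Numerically stable softmax along `axis`.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

# Two categorical distributions over three classes, given as log-probabilities.
x = np.log(np.array([[0.1, 0.2, 0.7], [0.3, 0.3, 0.4]]))
y = gumbel_softmax(x, tau=0.5, axis=1, rng=np.random.default_rng(0))
```

Each row of `y` is a non-negative vector that sums to 1. Because of the random noise, repeated calls without a fixed `rng` return different samples.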

Examples:

from dragon.vm import torch

m = torch.nn.GumbelSoftmax(tau=0.5, dim=1)
x = torch.randn(2, 3)
y = m(x)  # Same shape as x; values along dim=1 sum to 1

__init__

GumbelSoftmax.__init__(
  tau=1,
  dim=None,
  inplace=False
)

Create a GumbelSoftmax module.

Parameters:
  • tau (Union[number, dragon.vm.torch.Tensor], optional, default=1) – The temperature \(\tau\). Lower values push the output closer to a one-hot vector.
  • dim (int, required) – The dimension along which the softmax is applied.
  • inplace (bool, optional, default=False) – Whether to do the operation in-place.
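
To illustrate how the temperature and the reduced dimension shape the output, the following NumPy sketch applies two temperatures to the same noisy logits. It does not call Dragon; the `softmax` helper, the sample values, and the chosen temperatures are assumptions made for illustration only.

```python
import numpy as np

def softmax(z, axis):
    # Numerically stable softmax along `axis` (hypothetical helper).
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
logits = np.log(np.array([[0.1, 0.2, 0.7], [0.3, 0.3, 0.4]]))
# Draw the Gumbel(0, 1) noise once so both temperatures see the same sample.
noisy = logits + rng.gumbel(loc=0.0, scale=1.0, size=logits.shape)

soft = softmax(noisy / 5.0, axis=1)   # high temperature: flatter rows
sharp = softmax(noisy / 0.1, axis=1)  # low temperature: rows closer to one-hot
```

Reducing along axis 1 normalizes each row independently, which mirrors `dim=1` for a 2-D input. Lowering the temperature leaves the position of the largest entry in each row unchanged while concentrating more probability mass on it.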