argmax

dragon.vm.torch.argmax(
  input,
  dim,
  keepdim=False,
  out=None
)

Return the indices of the maximum elements along the given dimension.

dim can be negative:

from dragon.vm import torch

# A negative dim counts backward from the last dimension (-1 is the last)
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(torch.argmax(x, dim=1))
print(torch.argmax(x, dim=-1))  # Equivalent
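The two calls above are equivalent because a negative dim is mapped onto a non-negative one by adding the number of dimensions. The helper below is a minimal pure-Python sketch of that mapping; normalize_dim is a hypothetical name used only for illustration (it is not a Dragon function), and raising IndexError for an out-of-range dim is an assumption about the error behavior.

```python
def normalize_dim(dim, ndim):
    """Map a possibly negative dim onto [0, ndim) (hypothetical helper, not part of Dragon)."""
    # Valid dims for an ndim-D tensor are -ndim .. ndim - 1.
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for a {ndim}-D tensor")
    # -1 refers to the last dimension, -2 to the second-to-last, and so on.
    return dim + ndim if dim < 0 else dim


# For the 2-D tensor above, dim=-1 and dim=1 select the same axis.
print(normalize_dim(-1, 2))  # 1
print(normalize_dim(1, 2))   # 1
```

For a 3-D tensor, dim=-2 maps to 1 in the same way.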
Parameters:
  • input (dragon.vm.torch.Tensor) The input tensor.
  • dim (int) The dimension to reduce.
  • keepdim (bool, optional, default=False) Whether to keep the reduced dimension.
  • out (dragon.vm.torch.Tensor, optional) The output tensor.
Returns:

dragon.vm.torch.Tensor The indices of the maximum elements.
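To make the dim and keepdim semantics concrete, here is a reference sketch in plain Python for 2-D nested lists. It assumes the usual result layout for reductions: the reduced dimension is dropped, or kept with size 1 when keepdim=True. argmax_2d is a hypothetical helper for illustration, not part of Dragon; it returns the first index among equal maxima, a tie-breaking rule the source does not specify for dragon.vm.torch.argmax.

```python
def argmax_2d(rows, dim, keepdim=False):
    """Reference argmax over a 2-D nested list (hypothetical helper, not part of Dragon)."""
    # Map a negative dim onto 0 or 1 for a 2-D input.
    if not -2 <= dim < 2:
        raise IndexError(f"dim {dim} out of range for a 2-D input")
    dim = dim + 2 if dim < 0 else dim
    if dim == 1:
        # Reduce along each row: position of the largest element within the row.
        # max() over range() returns the first index among equal maxima.
        result = [max(range(len(r)), key=r.__getitem__) for r in rows]
        # keepdim=True keeps the reduced axis with size 1: shape (n, 1).
        return [[i] for i in result] if keepdim else result
    # Reduce along each column: index of the row holding the largest element.
    cols = list(zip(*rows))
    result = [max(range(len(c)), key=c.__getitem__) for c in cols]
    # keepdim=True keeps the reduced axis with size 1: shape (1, m).
    return [result] if keepdim else result


x = [[1, 2, 3], [4, 5, 6]]
print(argmax_2d(x, dim=1))                # [2, 2]
print(argmax_2d(x, dim=-1))               # [2, 2]
print(argmax_2d(x, dim=0))                # [1, 1, 1]
print(argmax_2d(x, dim=1, keepdim=True))  # [[2], [2]]
print(argmax_2d(x, dim=0, keepdim=True))  # [[1, 1, 1]]
```

The input mirrors the tensor in the example above, so the dim=1 and dim=-1 results show the same equivalence on plain lists.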