argmax

dragon.vm.torch.argmax(
  input,
  dim=None,
  keepdim=False,
  out=None
)

Return the indices of the maximum elements along the given dimension.

The argument dim can be negative or None:

from dragon.vm import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])

# A negative ``dim`` counts back from the last axis,
# so for this 2-D tensor -1 refers to the same axis as 1
print(torch.argmax(x, 1))
print(torch.argmax(x, -1))  # Equivalent

# If ``dim`` is None, the input is flattened (vector-style reduction)
# and a single scalar index is returned
print(torch.argmax(x))  # 5
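The dim rules above can be cross-checked against NumPy. The sketch below is an assumption-based reference, not Dragon's implementation. It assumes NumPy is available, and argmax_reference is a hypothetical helper that normalizes a negative dim and flattens the input when dim is None.

```python
import numpy as np


def argmax_reference(x, dim=None):
    """Hypothetical NumPy helper mirroring the ``dim`` rules described above."""
    x = np.asarray(x)
    if dim is None:
        # Vector-style reduction: flatten, then return one scalar index.
        return int(np.argmax(x.reshape(-1)))
    # A negative dim counts back from the last axis (-1 is the last one).
    axis = dim + x.ndim if dim < 0 else dim
    return np.argmax(x, axis=axis)


x = [[1, 2, 3], [4, 5, 6]]
print(argmax_reference(x, 1).tolist())   # [2, 2]
print(argmax_reference(x, -1).tolist())  # [2, 2]
print(argmax_reference(x))               # 5
```

The flattened index 5 points at the value 6, which is the last element of the 2x3 input in row-major order.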
Parameters:
  • input (dragon.vm.torch.Tensor) – The input tensor.
  • dim (int, optional) – The dimension to reduce. If None, the input is flattened.
  • keepdim (bool, optional, default=False) – Whether to keep the reduced dimension.
  • out (dragon.vm.torch.Tensor, optional) – The optional output tensor.
Returns:
  dragon.vm.torch.Tensor – The indices of the maximum elements.
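The effect of keepdim on the output shape can be sketched with NumPy. This is an illustration under assumptions, not Dragon's code. It assumes NumPy is available and that keepdim=True re-inserts the reduced axis with size 1, as in PyTorch. argmax_keepdim is a hypothetical helper that only handles an integer dim.

```python
import numpy as np


def argmax_keepdim(x, dim, keepdim=False):
    """Hypothetical NumPy sketch of how ``keepdim`` changes the output shape."""
    x = np.asarray(x)
    axis = dim % x.ndim  # normalize a negative dim, e.g. -1 -> ndim - 1
    index = np.argmax(x, axis=axis)
    if keepdim:
        # Re-insert the reduced axis with size 1 so the rank is preserved.
        index = np.expand_dims(index, axis)
    return index


x = np.array([[1, 2, 3], [4, 5, 6]])
print(argmax_keepdim(x, 1).shape)                # (2,)
print(argmax_keepdim(x, 1, keepdim=True).shape)  # (2, 1)
```

Keeping the reduced axis makes the index tensor broadcast against the input, which is convenient when the indices are later used to gather values along the same dimension.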