gather

dragon.vm.torch.gather(
  input,
  dim,
  index,
  out=None
)

Gather elements from input along dimension dim, at the positions given by index.

input and index must have the same number of dimensions, and the output takes the shape of index. For a 3-D input, the output is gathered as:

out[i, j, k] = input[index[i, j, k], j, k]  # if dim == 0
out[i, j, k] = input[i, index[i, j, k], k]  # if dim == 1
out[i, j, k] = input[i, j, index[i, j, k]]  # if dim == 2
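To make the three cases concrete, here is a minimal pure-Python sketch of the same rule, generalized to any number of dimensions. It works on rectangular nested lists rather than Dragon tensors and assumes non-negative, in-range indices. `gather_reference`, `shape_of` and `get_item` are hypothetical names used only for illustration. They are not part of `dragon.vm.torch`.

```python
def shape_of(nested):
    """Return the shape of a rectangular nested list (hypothetical helper)."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        if not nested:
            break
        nested = nested[0]
    return shape


def get_item(nested, position):
    """Follow a sequence of indices into a nested list (hypothetical helper)."""
    for i in position:
        nested = nested[i]
    return nested


def gather_reference(input_, dim, index):
    """Sketch of gather semantics (hypothetical, not the Dragon implementation).

    The output takes the shape of `index`. Each output element reads `input_`
    at the same position, except that the coordinate along `dim` is replaced
    by the value stored in `index` at that position.
    """
    in_shape, idx_shape = shape_of(input_), shape_of(index)
    if len(in_shape) != len(idx_shape):
        raise ValueError("input and index must have the same number of dimensions")
    if not 0 <= dim < len(in_shape):
        raise ValueError("dim out of range")

    def build(prefix):
        if len(prefix) == len(idx_shape):
            position = list(prefix)
            position[dim] = get_item(index, prefix)  # swap in the gathered coordinate
            return get_item(input_, position)
        return [build(prefix + (i,)) for i in range(idx_shape[len(prefix)])]

    return build(())


x = [[1, 2], [3, 4]]
index = [[0, 0], [0, 1]]
print(gather_reference(x, 0, index))  # [[1, 2], [1, 4]]
print(gather_reference(x, 1, index))  # [[1, 1], [3, 4]]
```

The recursive `build` walks every position of index, so the output shape always follows index, which matches the formulas above.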

Examples:

from dragon.vm import torch

x = torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 0], [0, 1]])
print(torch.gather(x, 0, index))  # [[1, 2], [1, 4]]
print(torch.gather(x, 1, index))  # [[1, 1], [3, 4]]

Parameters: