gather
dragon.vm.torch.gather(input, dim, index, out=None)
Gather elements from input along the given dimension, at the positions given by index.
input and index must have the same number of dimensions. For a 3-d input, the output is gathered as:
    out[i, j, k] = input[index[i, j, k], j, k]  # dim = 0
    out[i, j, k] = input[i, index[i, j, k], k]  # dim = 1
    out[i, j, k] = input[i, j, index[i, j, k]]  # dim = 2
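The indexing rule above can be sketched in plain Python on nested lists. This is only an illustrative reference, not the library's implementation: the helper names (`shape_of`, `get_at`, `set_at`, `gather_reference`) are hypothetical, and the sketch assumes the output takes the shape of the index, as the formulas suggest.

```python
from itertools import product


def shape_of(nested):
    """Return the shape of a non-empty rectangular nested list."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0]
    return tuple(shape)


def get_at(nested, position):
    """Read the element at a sequence of indices."""
    for p in position:
        nested = nested[p]
    return nested


def set_at(nested, position, value):
    """Write the element at a sequence of indices."""
    for p in position[:-1]:
        nested = nested[p]
    nested[position[-1]] = value


def zeros(shape):
    """Build a nested list of zeros with the given shape."""
    if len(shape) == 1:
        return [0] * shape[0]
    return [zeros(shape[1:]) for _ in range(shape[0])]


def gather_reference(values, dim, index):
    """Hypothetical reference for the gather rule (assumption: out has index's shape)."""
    out_shape = shape_of(index)
    if len(shape_of(values)) != len(out_shape):
        raise ValueError("values and index must have the same number of dimensions")
    out = zeros(out_shape)
    # Visit every output position; replace the coordinate at `dim`
    # with the value stored in `index` at that same position.
    for position in product(*(range(n) for n in out_shape)):
        source = list(position)
        source[dim] = get_at(index, position)
        set_at(out, position, get_at(values, source))
    return out


x = [[1, 2], [3, 4]]
index = [[0, 0], [0, 1]]
print(gather_reference(x, 0, index))  # [[1, 2], [1, 4]]
print(gather_reference(x, 1, index))  # [[1, 1], [3, 4]]
```

The loop works for any number of dimensions, so the three 3-d formulas are the cases dim = 0, 1, and 2 of the same substitution.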
Examples:
from dragon.vm import torch

x = torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 0], [0, 1]])
print(torch.gather(x, 0, index))  # [[1, 2], [1, 4]]
print(torch.gather(x, 1, index))  # [[1, 1], [3, 4]]
- Parameters:
- input (dragon.vm.torch.Tensor) – The input tensor.
- dim (int) – The dimension along which to gather; the index values select positions along this dimension.
- index (dragon.vm.torch.Tensor) – The index tensor.
- out (dragon.vm.torch.Tensor, optional) – The output tensor.