scatter
dragon.vm.torch.scatter(input, dim, index, src, out=None)

Update the elements of input along the given dimension at the positions specified by index.
The number of dimensions of input, index, and src should be the same. For a 3-D input, the output is updated as:

```python
out[index[i, j, k], j, k] = src[i, j, k]  # dim is 0
out[i, index[i, j, k], k] = src[i, j, k]  # dim is 1
out[i, j, index[i, j, k]] = src[i, j, k]  # dim is 2
```
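The update rule above generalizes to any rank: for every coordinate of index, copy the matching element of src into the output at that coordinate, with the dim-th component replaced by the index value. The sketch below illustrates this in plain Python. It is a hypothetical reference written for illustration only, not Dragon's implementation: `scatter_reference` and its helpers are made-up names, nested lists stand in for tensors, and it assumes out-of-place behavior (as with `out=None`). It does not model the `out` argument, devices, or dtypes. For duplicate indices, the last write wins in this sketch; no claim is made about what Dragon does in that case.

```python
import copy
from itertools import product
from typing import Any, List, Sequence, Tuple, Union

Nested = List[Any]  # rectangular nested list standing in for a tensor


def _shape(x: Any) -> Tuple[int, ...]:
    """Shape of a rectangular nested list (a non-list has shape ())."""
    dims: List[int] = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0] if x else None
    return tuple(dims)


def _get(x: Any, pos: Sequence[int]) -> Any:
    """Read the element of nested list x at coordinates pos."""
    for p in pos:
        x = x[p]
    return x


def _set(x: Nested, pos: Sequence[int], value: Any) -> None:
    """Write value into nested list x at coordinates pos."""
    for p in pos[:-1]:
        x = x[p]
    x[pos[-1]] = value


def scatter_reference(inp: Nested, dim: int, index: Nested,
                      src: Union[Nested, int, float]) -> Nested:
    """Out-of-place scatter over nested lists (hypothetical sketch).

    For every coordinate pos of index, copies src[pos] (or the scalar src)
    into the output at pos with coordinate ``dim`` replaced by index[pos].
    """
    ndim = len(_shape(inp))
    if ndim == 0 or len(_shape(index)) != ndim:
        raise ValueError("input and index must have the same, nonzero rank")
    scalar = not isinstance(src, list)
    if not scalar and len(_shape(src)) != ndim:
        raise ValueError("src must have the same rank as input")
    out = copy.deepcopy(inp)  # leave the caller's input untouched
    for pos in product(*(range(n) for n in _shape(index))):
        target = list(pos)
        target[dim] = _get(index, pos)  # negative dim works via list indexing
        _set(out, target, src if scalar else _get(src, pos))
    return out


# 3-D case with dim = 2: index[i][j][k] = 1 - k reverses the last axis.
src = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
index = [[[1, 0], [1, 0]], [[1, 0], [1, 0]]]
zeros = [[[0, 0], [0, 0]], [[0, 0], [0, 0]]]
print(scatter_reference(zeros, 2, index, src))
# [[[2, 1], [4, 3]], [[6, 5], [8, 7]]]
```

Iterating over the shape of index rather than the shape of input means index may be smaller than input along any axis; positions it never addresses keep their original values.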
Examples:
```python
from dragon.vm import torch

y = torch.tensor([[1, 2], [3, 4]])
x = torch.tensor([[5, 6], [7, 8]])
index = torch.tensor([[0, 1], [1, 0]])
print(torch.scatter(y, 0, index, x))  # [[5, 8], [7, 6]]
print(torch.scatter(y, 1, index, x))  # [[5, 6], [8, 7]]
print(torch.scatter(y, 0, index, 8))  # [[8, 8], [8, 8]]
```
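To see where the first result comes from, apply the dim = 0 rule to each position (i, j) of index. Each write copies one element of x into row index[i, j] of the output, keeping the column j:

$$
\begin{aligned}
\mathrm{out}[\mathrm{index}[0,0],\,0] &= \mathrm{out}[0,0] = x[0,0] = 5 \\
\mathrm{out}[\mathrm{index}[0,1],\,1] &= \mathrm{out}[1,1] = x[0,1] = 6 \\
\mathrm{out}[\mathrm{index}[1,0],\,0] &= \mathrm{out}[1,0] = x[1,0] = 7 \\
\mathrm{out}[\mathrm{index}[1,1],\,1] &= \mathrm{out}[0,1] = x[1,1] = 8
\end{aligned}
$$

This gives [[5, 8], [7, 6]]. Because index addresses every position of y in this example, none of y's original values survive.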
- Parameters:
- input (dragon.vm.torch.Tensor) – The input tensor.
- dim (int) – The dimension along which index values are applied.
- index (dragon.vm.torch.Tensor) – The index tensor.
- src (Union[dragon.vm.torch.Tensor, number]) – The tensor or scalar value to update from.
- out (dragon.vm.torch.Tensor, optional) – The output tensor.