scatter_add
dragon.vm.torch.scatter_add(input, dim, index, src, out=None)

Add the elements of src along the given dimension, at the positions specified by index.
The number of dimensions of input, index, and src should be the same. Starting from the values of input, the output of a 3-d input is updated as:

```python
out[index[i, j, k], j, k] += src[i, j, k]  # dim is 0
out[i, index[i, j, k], k] += src[i, j, k]  # dim is 1
out[i, j, index[i, j, k]] += src[i, j, k]  # dim is 2
```

When src is a number, that value is used in place of every src[i, j, k].
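The update rule above can be modelled in plain Python to make the semantics concrete. This is a minimal sketch, not Dragon's implementation: the helper names (`scatter_add_reference`, `_shape`, `_get`) are hypothetical, and it assumes rectangular, non-empty nested lists instead of tensors. It generalizes the 3-d rule to any rank by replacing the `dim`-th coordinate of each index position with the value stored in index.

```python
import copy
import itertools


def _shape(nested):
    """Return the shape of a rectangular, non-empty nested list."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0]
    return tuple(shape)


def _get(nested, idx):
    """Follow a sequence of indices into a nested list."""
    for i in idx:
        nested = nested[i]
    return nested


def scatter_add_reference(input_, dim, index, src):
    """Hypothetical pure-Python model of the scatter_add update rule.

    The output starts as a copy of ``input_``. For every position ``pos`` of
    ``index``, the target position is ``pos`` with its ``dim``-th coordinate
    replaced by ``index[pos]``, and ``src[pos]`` (or the scalar ``src``) is
    added there. Repeated targets accumulate.
    """
    out = copy.deepcopy(input_)  # leave the caller's input untouched
    for pos in itertools.product(*(range(n) for n in _shape(index))):
        target = list(pos)
        target[dim] = _get(index, pos)
        value = src if isinstance(src, (int, float)) else _get(src, pos)
        row = _get(out, target[:-1])  # innermost list holding the target
        row[target[-1]] += value
    return out


y = [[1, 2], [3, 4]]
x = [[5, 6], [7, 8]]
index = [[0, 0], [0, 0]]
print(scatter_add_reference(y, 0, index, x))  # [[13, 16], [3, 4]]
print(scatter_add_reference(y, 1, index, x))  # [[12, 2], [18, 4]]
print(scatter_add_reference(y, 0, index, 8))  # [[17, 18], [3, 4]]
```

The printed results match the documented examples, which is a useful sanity check that the output really starts from input rather than from zeros.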
Examples:

```python
from dragon.vm import torch

y = torch.tensor([[1, 2], [3, 4]])
x = torch.tensor([[5, 6], [7, 8]])
index = torch.tensor([[0, 0], [0, 0]])
print(torch.scatter_add(y, 0, index, x))  # [[13, 16], [3, 4]]
print(torch.scatter_add(y, 1, index, x))  # [[12, 2], [18, 4]]
print(torch.scatter_add(y, 0, index, 8))  # [[17, 18], [3, 4]]
```
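To cross-check the example outputs without Dragon installed, the same semantics can be emulated with NumPy. This is a sketch under stated assumptions: `scatter_add_numpy` is a hypothetical helper, not part of Dragon, and it relies on NumPy's `np.indices` and the unbuffered `np.add.at`, which sums every occurrence of a repeated index (plain fancy-index assignment with `+=` would not).

```python
import numpy as np


def scatter_add_numpy(input_, dim, index, src):
    """Hypothetical NumPy emulation of scatter_add for cross-checking.

    Builds one coordinate array per axis: axis ``dim`` takes its coordinates
    from ``index``; every other axis uses the element's own position.
    ``np.add.at`` then accumulates ``src`` without buffering.
    """
    out = np.asarray(input_).copy()  # output starts from the input values
    index = np.asarray(index)
    # A scalar src is broadcast to the shape of index.
    src = np.broadcast_to(np.asarray(src, dtype=out.dtype), index.shape)
    coords = list(np.indices(index.shape))  # own position along each axis
    coords[dim] = index  # replace the dim-th coordinate with index values
    np.add.at(out, tuple(coords), src)
    return out


y = [[1, 2], [3, 4]]
x = [[5, 6], [7, 8]]
index = [[0, 0], [0, 0]]
print(scatter_add_numpy(y, 1, index, x).tolist())  # [[12, 2], [18, 4]]
# Repeated indices accumulate: positions 0 receives 1 + 3, position 2 receives 2.
print(scatter_add_numpy([0, 0, 0], 0, [0, 2, 0], [1, 2, 3]).tolist())  # [4, 0, 2]
```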
- Parameters:
  - input (dragon.vm.torch.Tensor) – The input tensor.
  - dim (int) – The dimension along which index selects positions.
  - index (dragon.vm.torch.Tensor) – The index tensor.
  - src (Union[dragon.vm.torch.Tensor, number]) – The tensor or number to add from.
  - out (dragon.vm.torch.Tensor, optional) – The output tensor.