chunk
dragon.vm.torch.chunk(tensor, chunks, dim=0, copy=True)

Split the input into the given number of chunks along a dimension.
Examples:
```python
from dragon.vm import torch

x = torch.tensor([[1, 2], [3, 4], [5, 6]])
# Shape: (3, 2) -> (2, 2), (1, 2)
print(torch.chunk(x, chunks=2))
```
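The shape comment in the example follows from how the chunk size is chosen. The sketch below assumes PyTorch-style semantics, consistent with the (3, 2) -> (2, 2), (1, 2) result: every chunk holds ceil(size / chunks) slices except possibly the last, which takes the remainder. Under that assumption fewer than `chunks` pieces can be produced. The helper `chunk_sizes` is hypothetical and is not part of dragon.vm.torch.

```python
import math
from typing import List


def chunk_sizes(dim_size: int, chunks: int) -> List[int]:
    """Hypothetical helper: size of each chunk along the split dimension.

    Assumes PyTorch-style semantics: each chunk has ceil(dim_size / chunks)
    slices except possibly the last, so fewer than `chunks` pieces may result.
    """
    if chunks <= 0:
        raise ValueError("chunks must be a positive integer")
    step = math.ceil(dim_size / chunks)
    sizes = []
    start = 0
    while start < dim_size:
        # The last chunk only gets whatever remains.
        sizes.append(min(step, dim_size - start))
        start += step
    return sizes


print(chunk_sizes(3, 2))  # [2, 1], matching (3, 2) -> (2, 2), (1, 2)
print(chunk_sizes(6, 4))  # [2, 2, 2], only three chunks under this assumption
```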
`dim` can be negative:

```python
from dragon.vm import torch

x = torch.tensor([[1, 2], [3, 4], [5, 6]])
print(torch.chunk(x, 2, dim=1))
print(torch.chunk(x, 2, dim=-1))  # Equivalent
```
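A negative `dim` counts from the last dimension, so on a 2-D input `dim=-1` refers to the same axis as `dim=1`. The pure-Python sketch below illustrates that normalization together with a split along either axis of a nested list. Both `normalize_dim` and `chunk_2d` are hypothetical helpers, and the ceiling-division chunk size is an assumption consistent with the examples, not a statement about dragon's implementation.

```python
from typing import List


def normalize_dim(dim: int, ndim: int) -> int:
    """Map a possibly negative dim into the range [0, ndim)."""
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for a {ndim}-D input")
    return dim + ndim if dim < 0 else dim


def chunk_2d(rows: List[List[int]], chunks: int, dim: int = 0) -> List[List[List[int]]]:
    """Hypothetical pure-Python chunk of a 2-D nested list along `dim`."""
    dim = normalize_dim(dim, 2)
    size = len(rows) if dim == 0 else len(rows[0])
    step = max(1, -(-size // chunks))  # ceiling division, at least 1
    pieces = []
    for start in range(0, size, step):
        if dim == 0:
            # Split along rows.
            pieces.append(rows[start:start + step])
        else:
            # Split along columns: slice every row the same way.
            pieces.append([row[start:start + step] for row in rows])
    return pieces


x = [[1, 2], [3, 4], [5, 6]]
print(chunk_2d(x, 2, dim=1))                            # [[[1], [3], [5]], [[2], [4], [6]]]
print(chunk_2d(x, 2, dim=-1) == chunk_2d(x, 2, dim=1))  # True
```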
- Parameters:
- tensor (dragon.vm.torch.Tensor) – The input tensor.
- chunks (int) – The number of chunks to split into.
- dim (int, optional, default=0) – The dimension along which to split.
- copy (bool, optional, default=True) – If True, return copies of the input data; otherwise, return views of the input.
- Returns:
Sequence[dragon.vm.torch.Tensor] – The output tensors.
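The `copy` parameter decides whether each output owns its data or shares memory with the input. How dragon implements views internally is not described here; as an analogy only, the sketch below uses Python's built-in `memoryview`, whose slices share memory with the underlying buffer, to contrast views (`copy=False`) with copies (`copy=True`) on a 1-D byte buffer. `chunk_buffer` is a hypothetical helper.

```python
import math
from typing import List, Union


def chunk_buffer(buf: bytearray, chunks: int, copy: bool = True) -> List[Union[bytes, memoryview]]:
    """Hypothetical 1-D chunk: independent copies if copy=True, views if copy=False."""
    step = max(1, math.ceil(len(buf) / chunks))
    view = memoryview(buf)
    pieces = []
    for start in range(0, len(buf), step):
        piece = view[start:start + step]  # slicing a memoryview does not copy
        pieces.append(bytes(piece) if copy else piece)  # bytes() makes a copy
    return pieces


data = bytearray(b"abcdef")
views = chunk_buffer(data, 2, copy=False)
copies = chunk_buffer(data, 2, copy=True)
views[0][0] = ord("z")  # writes through to `data`
print(data)             # bytearray(b'zbcdef')
print(copies[0])        # b'abc', unaffected by the write
```

Views avoid allocating new memory, but any in-place change to a view is visible in the input; copies cost memory and are safe to modify independently.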