chunk

dragon.vm.torch.chunk(
  tensor,
  chunks,
  dim=0
)

Split the input tensor into the given number of chunks along a dimension.

Examples:

x = torch.tensor([[1, 2], [3, 4], [5, 6]])
# Shape: (3, 2) -> (2, 2), (1, 2)
print(torch.chunk(x, chunks=2))
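
The shape comment above can be reproduced with a little arithmetic. The following is a minimal pure-Python sketch, not the library's implementation: `chunk_sizes` is a hypothetical helper, and the ceiling-division rule is an assumption chosen because it matches the (3, 2) -> (2, 2), (1, 2) example. Behavior for other uneven sizes may differ in practice.

```python
import math


def chunk_sizes(length, chunks):
    """Return the chunk sizes along an axis of the given length.

    Hypothetical helper. Assumes every chunk holds ceil(length / chunks)
    elements except the last, which takes whatever remains.
    """
    if chunks <= 0:
        raise ValueError("chunks must be a positive integer")
    step = math.ceil(length / chunks)
    sizes = []
    remaining = length
    while remaining > 0:
        size = min(step, remaining)
        sizes.append(size)
        remaining -= size
    return sizes


# Axis of length 3 split into 2 chunks, as in the example above.
print(chunk_sizes(3, 2))  # [2, 1]
```

Under this assumption, the two sizes along dim 0 are 2 and 1, which gives the output shapes (2, 2) and (1, 2).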

The dim can be negative, in which case it counts back from the last axis (dim=-1 is the last axis):

x = torch.tensor([[1, 2], [3, 4], [5, 6]])
print(torch.chunk(x, 2, dim=1))
print(torch.chunk(x, 2, dim=-1))  # Equivalent
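
The equivalence of dim=1 and dim=-1 for a 2-D tensor follows from the usual negative-index convention. The sketch below illustrates that convention only; `normalize_dim` is a hypothetical helper, not part of the library, and the bounds check is an assumption about how out-of-range values would be treated.

```python
def normalize_dim(dim, ndim):
    """Map a possibly negative axis index to its non-negative equivalent.

    Hypothetical helper: assumes valid values lie in [-ndim, ndim).
    """
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} is out of range for {ndim} dimensions")
    return dim + ndim if dim < 0 else dim


# x above has 2 dimensions, so dim=-1 and dim=1 name the same axis.
print(normalize_dim(-1, 2))  # 1
print(normalize_dim(1, 2))   # 1
```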

Parameters:
  • tensor (dragon.vm.torch.Tensor) – The input tensor.
  • chunks (int) – The number of chunks to split into.
  • dim (int, optional, default=0) – The dimension along which to split.
Returns:

Sequence[dragon.vm.torch.Tensor] – The output tensors.