chunk

dragon.vm.torch.chunk(
  tensor,
  chunks,
  dim=0,
  copy=True
)

Split the input into a specified number of chunks along a dimension.

Examples:

from dragon.vm import torch

x = torch.tensor([[1, 2], [3, 4], [5, 6]])
# Shape: (3, 2) -> (2, 2), (1, 2)
print(torch.chunk(x, chunks=2))
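To show where the (2, 2) and (1, 2) shapes come from, here is a minimal pure-Python sketch on nested lists. It assumes the PyTorch convention that each chunk holds ceil(n / chunks) rows and the last chunk holds whatever remains. The helper chunk_rows is hypothetical and is not part of dragon.vm.torch.

```python
import math
from typing import List, Sequence


def chunk_rows(rows: Sequence[Sequence[int]], chunks: int) -> List[List[List[int]]]:
    """Hypothetical helper: split a 2-D nested list along dim 0.

    Assumes the PyTorch convention: every piece holds ceil(n / chunks) rows,
    except possibly the last one, which holds the remainder.
    """
    if chunks <= 0:
        raise ValueError("chunks must be a positive integer")
    n = len(rows)
    # Rows per chunk; max(1, ...) keeps range() valid for empty input.
    step = max(1, math.ceil(n / chunks))
    return [[list(r) for r in rows[i:i + step]] for i in range(0, n, step)]


x = [[1, 2], [3, 4], [5, 6]]
print(chunk_rows(x, chunks=2))  # [[[1, 2], [3, 4]], [[5, 6]]]
```

Under this convention the call can return fewer than chunks pieces. For example, 6 rows with chunks=4 gives a step of 2 and only 3 pieces.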

dim can be negative:

x = torch.tensor([[1, 2], [3, 4], [5, 6]])
print(torch.chunk(x, 2, dim=1))
print(torch.chunk(x, 2, dim=-1))  # Equivalent
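A negative dim counts from the last dimension, so for a 2-D input dim=-1 is the same as dim=1. The sketch below computes only the output shapes. It assumes the same ceil(size / chunks) rule as PyTorch, and chunk_shapes is a hypothetical helper, not a dragon.vm.torch function.

```python
import math
from typing import List, Tuple


def chunk_shapes(shape: Tuple[int, ...], chunks: int, dim: int = 0) -> List[Tuple[int, ...]]:
    """Hypothetical helper: output shapes of a chunk along `dim` (PyTorch-style rule assumed)."""
    if chunks <= 0:
        raise ValueError("chunks must be a positive integer")
    ndim = len(shape)
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for a {ndim}-D shape")
    dim %= ndim  # map a negative dim such as -1 to ndim - 1
    size = shape[dim]
    step = max(1, math.ceil(size / chunks))
    out = []
    for start in range(0, size, step):
        piece = list(shape)
        piece[dim] = min(step, size - start)  # the last piece may be smaller
        out.append(tuple(piece))
    return out


print(chunk_shapes((3, 2), 2, dim=1))   # [(3, 1), (3, 1)]
print(chunk_shapes((3, 2), 2, dim=-1))  # [(3, 1), (3, 1)]
```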

Parameters:
  • tensor (dragon.vm.torch.Tensor) The input tensor.
  • chunks (int) The number of chunks to split into.
  • dim (int, optional, default=0) The dimension to split along.
  • copy (bool, optional, default=True) If True, copy the data into new tensors; otherwise, return views of the input.
Returns:

Sequence[dragon.vm.torch.Tensor] The output tensors.
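The copy flag selects between the usual copy and view semantics. The following standard-library analogy uses bytearray and memoryview rather than dragon tensors. It assumes, without confirming it for dragon.vm.torch specifically, that a view shares storage with its input, so writes to the input are visible through the view.

```python
data = bytearray(b"\x01\x02\x03\x04\x05\x06")

# copy=True analogue: slicing a bytearray allocates new storage.
copied = data[0:4]
# copy=False analogue: a memoryview slice shares storage with `data`.
viewed = memoryview(data)[0:4]

data[0] = 99
print(copied[0])  # 1: the copy is unaffected
print(viewed[0])  # 99: the view sees the write
```

Views avoid an allocation per chunk, while copies keep each output independent of later changes to the input.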