split

dragon.vm.torch.split(
  tensor,
  split_size_or_sections,
  dim=0
)

Split the input into chunks along the given dimension.

Either a single size shared by every chunk or an explicit size for each chunk is accepted:

x = torch.tensor([1, 2, 3, 4, 5, 6])
# Shape: (6,) -> (4,), (2,)
print(torch.split(x, split_size_or_sections=4))
# Shape: (6,) -> (5,), (1,)
print(torch.split(x, split_size_or_sections=(5, 1)))
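To make the two modes concrete, here is a minimal pure-Python sketch of how the chunk sizes could be derived. It does not use dragon and is not its implementation. `chunk_sizes` and `split_list` are hypothetical helpers, and the requirement that explicit sections sum to the dimension length is an assumption borrowed from PyTorch's behavior.

```python
from typing import List, Sequence, Union


def chunk_sizes(length: int, split_size_or_sections: Union[int, Sequence[int]]) -> List[int]:
    """Compute the size of each output chunk along one dimension (hypothetical helper)."""
    if isinstance(split_size_or_sections, int):
        size = split_size_or_sections
        if size <= 0:
            raise ValueError("split size must be positive")
        # Emit as many full chunks of `size` as fit, then one smaller remainder chunk.
        sizes = [size] * (length // size)
        if length % size:
            sizes.append(length % size)
        return sizes
    sections = list(split_size_or_sections)
    # Assumption: explicit sections must cover the whole dimension exactly.
    if sum(sections) != length:
        raise ValueError("sections must sum to the dimension length")
    return sections


def split_list(values, split_size_or_sections):
    """Split a flat Python list the way split() splits a 1-D tensor (hypothetical helper)."""
    out, start = [], 0
    for size in chunk_sizes(len(values), split_size_or_sections):
        out.append(values[start:start + size])
        start += size
    return out


print(split_list([1, 2, 3, 4, 5, 6], 4))       # [[1, 2, 3, 4], [5, 6]]
print(split_list([1, 2, 3, 4, 5, 6], (5, 1)))  # [[1, 2, 3, 4, 5], [6]]
```

With an integer, only the last chunk may be smaller than the requested size. With a sequence, the caller controls every chunk boundary.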

The ``dim`` can be negative, in which case it counts backward from the last axis:

x = torch.tensor([[1, 2], [3, 4], [5, 6]])
# Shape: (3, 2) -> (3, 2)
print(torch.split(x, 2, dim=1))
print(torch.split(x, 2, dim=-1))  # Equivalent
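The equivalence of ``dim=1`` and ``dim=-1`` for a 2-D tensor follows from the usual rule of adding the number of dimensions to a negative axis. A minimal sketch of that normalization, using a hypothetical helper `normalize_dim` that is not part of dragon's API:

```python
def normalize_dim(dim: int, ndim: int) -> int:
    """Map a possibly negative axis to its non-negative index (hypothetical helper)."""
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for a {ndim}-D tensor")
    # A negative dim counts from the end: -1 is the last axis, -2 the one before it.
    return dim + ndim if dim < 0 else dim


shape = (3, 2)  # shape of x in the example above
print(normalize_dim(1, len(shape)), normalize_dim(-1, len(shape)))  # 1 1
```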
Parameters:
  • tensor (dragon.vm.torch.Tensor) – The input tensor.
  • split_size_or_sections (Union[int, Sequence[int]]) – The size of every chunk, or the size of each chunk.
  • dim (int, optional, default=0) – The dimension to split.
Returns:

Sequence[dragon.vm.torch.Tensor] – The output tensors.