split

dragon.vm.torch.split(
  tensor,
  split_size_or_sections,
  dim=0,
  copy=True
)

Split input into chunks along the given dimension.

Either a single size applied to every chunk or a sequence giving the size of each chunk is accepted:

x = torch.tensor([1, 2, 3, 4, 5, 6])
# Shape: (6,) -> (4,), (2,)
print(torch.split(x, split_size_or_sections=4))
# Shape: (6,) -> (5,), (1,)
print(torch.split(x, split_size_or_sections=(5, 1)))
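
The chunk shapes in the comments above can be worked out without running the library. The following is a minimal sketch in plain Python, not the library's implementation: chunk_sizes is a hypothetical helper, and the rule that explicit sections must sum to the dimension length is an assumption inferred from the (5, 1) example.

```python
from typing import List, Sequence, Union


def chunk_sizes(length: int, split_size_or_sections: Union[int, Sequence[int]]) -> List[int]:
    """Return the chunk sizes along the split dimension (hypothetical helper)."""
    if isinstance(split_size_or_sections, int):
        size = split_size_or_sections
        if size <= 0:
            raise ValueError("chunk size must be positive")
        # Full chunks first, then one smaller chunk for any remainder.
        full, rest = divmod(length, size)
        return [size] * full + ([rest] if rest else [])
    sizes = list(split_size_or_sections)
    # Assumption: explicit sections must cover the dimension exactly.
    if sum(sizes) != length:
        raise ValueError("sections must sum to the dimension length")
    return sizes


print(chunk_sizes(6, 4))       # [4, 2]
print(chunk_sizes(6, (5, 1)))  # [5, 1]
```

With an integer argument the last chunk is smaller whenever the dimension length is not a multiple of the chunk size, which is why a length of 6 split by 4 yields sizes 4 and 2.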

dim can be negative:

x = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(torch.split(x, 2, dim=1))
print(torch.split(x, 2, dim=-1))  # Equivalent
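
The equivalence of dim=1 and dim=-1 for a 2-D tensor follows from the usual convention of counting negative dimensions from the end. A minimal sketch of that convention is shown below; normalize_dim is a hypothetical helper, and the out-of-range check is an assumption rather than documented behavior of this function.

```python
def normalize_dim(dim: int, ndim: int) -> int:
    """Map a possibly negative dim to its non-negative equivalent (hypothetical helper)."""
    # Assumption: valid values lie in the range [-ndim, ndim).
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} is out of range for a {ndim}-D tensor")
    return dim + ndim if dim < 0 else dim


print(normalize_dim(-1, 2))  # 1
print(normalize_dim(1, 2))   # 1
```
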
Parameters:
  • tensor (dragon.vm.torch.Tensor) The input tensor.
  • split_size_or_sections (Union[int, Sequence[int]]) The size of every chunk, or a sequence giving the size of each chunk.
  • dim (int, optional, default=0) The dimension along which to split.
  • copy (bool, optional, default=True) Whether to copy the data or return views of the input.
Returns:

Sequence[dragon.vm.torch.Tensor] The output tensors.