split

dragon.vm.torch.split(tensor, split_size_or_sections, dim=0, copy=True)

Split the input tensor into chunks along the given dimension.
The chunks can be given either as a single size shared by every chunk (the last chunk may be smaller) or as a sequence holding the size of each chunk:
```python
from dragon.vm import torch  # Dragon's PyTorch-style frontend

x = torch.tensor([1, 2, 3, 4, 5, 6])
# Shape: (6,) -> (4,), (2,)
print(torch.split(x, split_size_or_sections=4))
# Shape: (6,) -> (5,), (1,)
print(torch.split(x, split_size_or_sections=(5, 1)))
```
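The chunking rule in the examples above can be sketched in plain Python. `chunk_sizes` and `split_sequence` are hypothetical helpers, not part of Dragon: an integer yields full chunks of that size plus a smaller remainder chunk, while a sequence gives the size of each chunk directly. Rejecting sections that do not sum to the dimension size follows PyTorch's `torch.split`; whether Dragon validates this the same way is an assumption.

```python
from typing import List, Sequence, Union


def chunk_sizes(dim_size: int, split_size_or_sections: Union[int, Sequence[int]]) -> List[int]:
    """Return the length of each chunk along the split dimension (hypothetical helper)."""
    if isinstance(split_size_or_sections, int):
        size = split_size_or_sections
        if size <= 0:
            raise ValueError("split size must be positive")
        # Full chunks of `size`, then one smaller chunk for any remainder.
        sizes = [size] * (dim_size // size)
        if dim_size % size:
            sizes.append(dim_size % size)
        return sizes
    sections = list(split_size_or_sections)
    # Assumption (mirrors PyTorch): explicit sections must cover the dimension exactly.
    if sum(sections) != dim_size:
        raise ValueError(f"sections {sections} do not sum to {dim_size}")
    return sections


def split_sequence(seq, split_size_or_sections):
    """Slice a Python sequence into consecutive chunks with the sizes above."""
    chunks, start = [], 0
    for n in chunk_sizes(len(seq), split_size_or_sections):
        chunks.append(seq[start:start + n])
        start += n
    return chunks


if __name__ == "__main__":
    data = [1, 2, 3, 4, 5, 6]
    print(split_sequence(data, 4))       # [[1, 2, 3, 4], [5, 6]]
    print(split_sequence(data, (5, 1)))  # [[1, 2, 3, 4, 5], [6]]
```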
dim can be negative, counting from the last dimension:

```python
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
# Shape: (2, 3) -> (2, 2), (2, 1)
print(torch.split(x, 2, dim=1))
print(torch.split(x, 2, dim=-1))  # Equivalent
```
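For the 2-D tensor above, `dim=-1` and `dim=1` name the same axis. A minimal sketch of the index normalization and the resulting chunk shapes, assuming an integer split size; `normalize_dim` and `split_shapes` are hypothetical helpers, not Dragon functions, and the out-of-range check mirrors PyTorch rather than anything confirmed for Dragon.

```python
from typing import List, Sequence, Tuple


def normalize_dim(dim: int, ndim: int) -> int:
    """Map a possibly negative axis index to its non-negative equivalent."""
    # Assumption: indices outside [-ndim, ndim) are rejected, as in PyTorch.
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for a {ndim}-d tensor")
    return dim + ndim if dim < 0 else dim


def split_shapes(shape: Sequence[int], split_size: int, dim: int = 0) -> List[Tuple[int, ...]]:
    """Shapes of the chunks produced by splitting `shape` with an integer size."""
    axis = normalize_dim(dim, len(shape))
    size = shape[axis]
    lengths = [split_size] * (size // split_size)
    if size % split_size:
        lengths.append(size % split_size)  # smaller trailing chunk
    # Only the split axis changes; every other dimension is kept.
    return [tuple(shape[:axis]) + (n,) + tuple(shape[axis + 1:]) for n in lengths]


if __name__ == "__main__":
    print(split_shapes((2, 3), 2, dim=1))   # [(2, 2), (2, 1)]
    print(split_shapes((2, 3), 2, dim=-1))  # [(2, 2), (2, 1)]
```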
- Parameters:
- tensor (dragon.vm.torch.Tensor) – The input tensor.
- split_size_or_sections (Union[int, Sequence[int]]) – The size of each chunk, or a sequence giving the size of every chunk.
- dim (int, optional, default=0) – The dimension to split.
- copy (bool, optional, default=True) – Whether to copy the data into new tensors (True) or return views of the input (False).
- Returns:
Sequence[dragon.vm.torch.Tensor] – The output tensors.
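The copy flag trades memory for independence from the input. As an analogy only, NumPy slicing shows the difference between views and copies; this uses NumPy, not Dragon, and that Dragon views reflect later in-place writes to the input in the same way is an assumption.

```python
import numpy as np

x = np.arange(6)  # [0, 1, 2, 3, 4, 5]

# Basic slicing returns views that share memory with x (analogous to copy=False).
views = [x[:4], x[4:]]
# Explicit copies own their data (analogous to copy=True).
copies = [v.copy() for v in views]

x[0] = 100  # write through the original array
print(views[0][0])   # 100: the view sees the change
print(copies[0][0])  # 0: the copy does not
```

Views avoid the cost of allocating new buffers, while copies are safe to modify without affecting the input.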