split

dragon.split(
  inputs,
  num_or_size_splits,
  axis=0,
  slice_points=None,
  **kwargs
)

Split input into chunks along the given axis.

Either the number of splits or the size of each split is accepted:

x = dragon.constant([[1, 2], [3, 4], [5, 6]])
# Shape: (3, 2) -> (2, 2), (1, 2)
print(dragon.split(x, num_or_size_splits=2))
# Shape: (3, 2) -> (1, 2), (2, 2)
print(dragon.split(x, num_or_size_splits=(1, 2)))
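
The explicit-size form above can be pictured as slicing at the running totals of the sizes. The following is a minimal pure-Python sketch of that idea, not Dragon's implementation: `split_by_sizes` is a hypothetical helper, and plain nested lists stand in for tensors.

```python
from itertools import accumulate


def split_by_sizes(rows, sizes):
    """Split a list of rows into consecutive chunks of the given sizes.

    Hypothetical helper for illustration only; not part of Dragon.
    """
    if sum(sizes) != len(rows):
        raise ValueError("sizes must sum to the length of the split axis")
    # Chunk boundaries are the running totals of the sizes: (1, 2) -> [0, 1, 3].
    bounds = [0, *accumulate(sizes)]
    return [rows[start:stop] for start, stop in zip(bounds, bounds[1:])]


x = [[1, 2], [3, 4], [5, 6]]
chunks = split_by_sizes(x, (1, 2))
print([len(c) for c in chunks])  # [1, 2]
```

The sketch only covers a list of explicit sizes; how an integer count is distributed over an axis that does not divide evenly is left to the library.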

The axis can be negative, in which case it counts back from the last axis:

x = dragon.constant([[1, 2], [3, 4], [5, 6]])
print(dragon.split(x, 2, axis=1))
print(dragon.split(x, 2, axis=-1))  # Equivalent
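
The equivalence of `axis=1` and `axis=-1` for a 2-D input follows the usual convention for negative axes. As a rough sketch of that convention (the `normalize_axis` helper is hypothetical and only assumed to mirror what the library does internally):

```python
def normalize_axis(axis, ndim):
    """Map a possibly negative axis to its non-negative equivalent.

    Hypothetical helper for illustration only; not part of Dragon.
    """
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of range for {ndim} dimensions")
    # A negative axis counts back from the end: -1 is the last axis.
    return axis + ndim if axis < 0 else axis


print(normalize_axis(-1, 2))  # 1
print(normalize_axis(1, 2))   # 1
```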

Optionally, pass slice_points to indicate where the splits fall:

x = dragon.constant([[1, 2], [3, 4], [5, 6]])
# Shape: (3, 2) -> (1, 2), (1, 2), (1, 2)
print(dragon.split(x, 3, slice_points=[1, 2]))
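
One way to read the example above is that each slice point is an index along the axis where a new chunk begins, so two points produce three chunks. The sketch below works under that assumption only; `split_at_points` is a hypothetical helper, and nested lists stand in for tensors.

```python
def split_at_points(rows, points):
    """Split a list of rows at the given boundary indices.

    Hypothetical helper for illustration only; it assumes slice points
    are indices at which a new chunk begins.
    """
    # Add the implicit start and end: [1, 2] on 3 rows -> [0, 1, 2, 3].
    bounds = [0, *points, len(rows)]
    return [rows[start:stop] for start, stop in zip(bounds, bounds[1:])]


x = [[1, 2], [3, 4], [5, 6]]
chunks = split_at_points(x, [1, 2])
print([len(c) for c in chunks])  # [1, 1, 1]
```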

Parameters:
  • inputs (dragon.Tensor) – The input tensor.
  • num_or_size_splits (Union[int, Sequence[int]]) – The number of chunks, or the size of each chunk.
  • axis (int, optional, default=0) – The axis to split, can be negative.
  • slice_points (Sequence[int], optional) – The slice points along the axis.
Returns:
  Sequence[dragon.Tensor] – The output tensors, one per chunk.