gather_elements

dragon.gather_elements(
  inputs,
  axis=0,
  **kwargs
)

Gather elements from the input along the given axis, using the values of index as positions on that axis.

The input and index should have the same number of dimensions. For a 3-d input, the output is gathered as:

out[i, j, k] = input[index[i, j, k], j, k]  # axis == 0
out[i, j, k] = input[i, index[i, j, k], k]  # axis == 1
out[i, j, k] = input[i, j, index[i, j, k]]  # axis == 2
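The rule above can be sketched as a plain loop: copy the output position, then replace its coordinate on the gathered axis with the matching index value. The helper `gather_elements_ref` below is hypothetical and written with NumPy for illustration only. It is not Dragon's implementation. It is cross-checked against NumPy's `np.take_along_axis`, which follows the same rule when index and input have the same rank.

```python
import itertools

import numpy as np


def gather_elements_ref(x, index, axis=0):
    """Hypothetical reference gather following out[...] = input[...] above.

    Illustration only; not part of Dragon.
    """
    x = np.asarray(x)
    index = np.asarray(index)
    if x.ndim != index.ndim:
        raise ValueError("input and index must have the same number of dimensions")
    # One output element per index element, so the output takes index's shape.
    out = np.empty(index.shape, dtype=x.dtype)
    for pos in itertools.product(*(range(n) for n in index.shape)):
        # Start from the output position, then swap in index[pos] on `axis`.
        src = list(pos)
        src[axis] = index[pos]
        out[pos] = x[tuple(src)]
    return out


x = np.arange(24).reshape(2, 3, 4)
# Values in [0, 2) are valid positions on every axis of x.
index = np.random.default_rng(0).integers(0, 2, size=(2, 3, 4))
for a in range(3):
    # The loop agrees with NumPy's built-in equivalent on each axis.
    assert np.array_equal(gather_elements_ref(x, index, a),
                          np.take_along_axis(x, index, axis=a))
```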

Examples:

import dragon

x = dragon.constant([[1, 2], [3, 4]])
index = dragon.constant([[0, 0], [0, 1]])
print(dragon.gather_elements([x, index], axis=0))  # [[1, 2], [1, 4]]
print(dragon.gather_elements([x, index], axis=1))  # [[1, 1], [3, 4]]
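The same two results can be reproduced without Dragon using NumPy's `np.take_along_axis`. This is a convenient way to check expected outputs when writing tests. It is a sketch that uses NumPy as a stand-in, not a Dragon API.

```python
import numpy as np

x = np.array([[1, 2], [3, 4]])
index = np.array([[0, 0], [0, 1]])
# axis=0: out[i, j] = x[index[i, j], j]
print(np.take_along_axis(x, index, axis=0).tolist())  # [[1, 2], [1, 4]]
# axis=1: out[i, j] = x[i, index[i, j]]
print(np.take_along_axis(x, index, axis=1).tolist())  # [[1, 1], [3, 4]]
```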

Parameters:
  • inputs (Sequence[dragon.Tensor]) – The input tensor and the index tensor.
  • axis (int, optional, default=0) – The axis along which to gather.
Returns:

dragon.Tensor – The output tensor.
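Because the output is indexed by the same positions as index, it follows index's shape. A common pattern is to pick one element per row, such as the value at each row's argmax. This page does not say whether Dragon accepts an index that is smaller than the input along the gathered axis. The sketch below therefore demonstrates the pattern with NumPy's `np.take_along_axis` as an assumed stand-in.

```python
import numpy as np

scores = np.array([[1, 5, 2], [7, 3, 9]])
# Column of each row's maximum, kept 2-d so its rank matches `scores`.
best = np.argmax(scores, axis=1)[:, None]  # [[1], [2]]
picked = np.take_along_axis(scores, best, axis=1)
print(picked.tolist(), picked.shape)  # [[5], [9]] (2, 1)
```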