Flatten

class dragon.vm.torch.nn.Flatten(
  start_dim=1,
  end_dim=-1
)

Flatten the input dimensions from start_dim through end_dim (inclusive) into a single dimension.

Examples:

from dragon.vm import torch

m = torch.nn.Flatten()
x = torch.ones(1, 2, 4, 4)
y = m(x)
print(y.size())  # (1, 32)
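With the defaults, dimensions 1 through 3 are merged, so the output size is (1, 2 * 4 * 4) = (1, 32). The sketch below reproduces only that shape arithmetic in plain Python. flatten_shape is a hypothetical helper, not part of dragon, and it assumes negative dimensions count from the last one, as in PyTorch.

```python
from math import prod


def flatten_shape(shape, start_dim=1, end_dim=-1):
    """Hypothetical helper: shape after merging dims start_dim..end_dim (inclusive)."""
    ndim = len(shape)
    # Assumption: negative indices count back from the last dimension.
    start = start_dim + ndim if start_dim < 0 else start_dim
    end = end_dim + ndim if end_dim < 0 else end_dim
    if not 0 <= start <= end < ndim:
        raise ValueError("invalid start_dim/end_dim for shape %s" % (shape,))
    # The merged dimension holds the product of every dimension in the range.
    merged = prod(shape[start:end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])


print(flatten_shape((1, 2, 4, 4)))        # (1, 32)
print(flatten_shape((1, 2, 4, 4), 0, 1))  # (2, 4, 4)
```

Dimensions outside the selected range pass through unchanged, which is why the leading batch dimension survives with the default start_dim=1.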

__init__

Flatten.__init__(
  start_dim=1,
  end_dim=-1
)

Create a Flatten module.

Parameters:
  • start_dim (int, optional, default=1) – The start dimension to flatten.
  • end_dim (int, optional, default=-1) – The end dimension to flatten.
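To illustrate how start_dim and end_dim select the merged range, here is a NumPy reference sketch. It assumes Flatten follows the PyTorch semantics (inclusive end_dim, negative indices counted from the last dimension, row-major element order preserved). NumPy stands in for dragon tensors here, and flatten_reference is a hypothetical name, not a dragon function.

```python
import numpy as np


def flatten_reference(x, start_dim=1, end_dim=-1):
    """Hypothetical NumPy reference for merging dims start_dim..end_dim (inclusive)."""
    ndim = x.ndim
    # Assumption: negative indices are normalized against the input rank.
    start = start_dim + ndim if start_dim < 0 else start_dim
    end = end_dim + ndim if end_dim < 0 else end_dim
    if not 0 <= start <= end < ndim:
        raise ValueError("invalid start_dim/end_dim for rank %d" % ndim)
    # Compute the merged size explicitly (np.prod of an empty tuple is 1.0).
    merged = int(np.prod(x.shape[start:end + 1]))
    # reshape keeps row-major (C) order, so no elements are reordered.
    return x.reshape(x.shape[:start] + (merged,) + x.shape[end + 1:])


x = np.ones((1, 2, 4, 4))
print(flatten_reference(x).shape)                          # (1, 32)
print(flatten_reference(x, start_dim=0, end_dim=1).shape)  # (2, 4, 4)
print(flatten_reference(x, start_dim=2).shape)             # (1, 2, 16)
```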