multi_head_attention_forward

dragon.vm.torch.nn.functional.multi_head_attention_forward(
  query,
  key,
  value,
  embed_dim_to_check,
  num_heads,
  in_proj_weight,
  in_proj_bias,
  out_proj_weight,
  out_proj_bias,
  dropout_p=0.0,
  training=True,
  need_weights=True,
  key_padding_mask=None,
  attn_mask=None,
  use_separate_proj_weight=False,
  q_proj_weight=None,
  k_proj_weight=None,
  v_proj_weight=None
)[source]

Apply the multi-head attention to input. [Vaswani et al., 2017].

Parameters:
  • query (dragon.vm.torch.Tensor) – The query tensor.
  • key (dragon.vm.torch.Tensor) – The key tensor.
  • value (dragon.vm.torch.Tensor) – The value tensor.
  • embed_dim_to_check (int) – The expected dimension of input embeddings.
  • num_heads (int) – The number of parallel heads.
  • in_proj_weight (dragon.vm.torch.Tensor) – The weight tensor for input projection.
  • in_proj_bias (dragon.vm.torch.Tensor) – The bias tensor for input projection.
  • out_proj_weight (dragon.vm.torch.Tensor) – The weight tensor for output projection.
  • out_proj_bias (dragon.vm.torch.Tensor) – The bias tensor for output projection.
  • dropout_p (float, optional, default=0.) – The probability of zeroing an attention weight.
  • training (bool, optional, default=True) – Apply dropout if True.
  • need_weights (bool, optional, default=True) – Return the attention weights or not.
  • key_padding_mask (dragon.vm.torch.Tensor, optional) – The mask to prevent attention to padded keys.
  • attn_mask (dragon.vm.torch.Tensor, optional) – The mask to prevent attention to certain positions.
  • use_separate_proj_weight (bool, optional, default=False) – Provide separate projection weights or not.
  • q_proj_weight (dragon.vm.torch.Tensor, optional) – The separate weight tensor for query projection.
  • k_proj_weight (dragon.vm.torch.Tensor, optional) – The separate weight tensor for key projection.
  • v_proj_weight (dragon.vm.torch.Tensor, optional) – The separate weight tensor for value projection.
Returns:

Tuple[dragon.vm.torch.Tensor, dragon.vm.torch.Tensor] – The output tensor and the attention weights.
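The computation this function performs can be sketched in plain NumPy. The following is a simplified illustration of the default path (joint `in_proj_weight`, no masks, no dropout), not the dragon implementation itself; the `(seq_len, batch, embed_dim)` input layout and the head-averaged attention weights follow the torch-style convention assumed above.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def mha_forward(query, key, value, num_heads, in_proj_weight, in_proj_bias,
                out_proj_weight, out_proj_bias):
    # query/key/value: (seq_len, batch, embed_dim).
    seq_len, batch, embed_dim = query.shape
    head_dim = embed_dim // num_heads

    # Joint input projection: in_proj_weight is (3*embed_dim, embed_dim),
    # rows stacked as [q; k; v].
    q = query @ in_proj_weight[:embed_dim].T + in_proj_bias[:embed_dim]
    k = key @ in_proj_weight[embed_dim:2 * embed_dim].T \
        + in_proj_bias[embed_dim:2 * embed_dim]
    v = value @ in_proj_weight[2 * embed_dim:].T + in_proj_bias[2 * embed_dim:]

    # Split into heads: (batch*num_heads, seq_len, head_dim).
    def split(x):
        return x.reshape(seq_len, batch * num_heads, head_dim).transpose(1, 0, 2)
    q, k, v = split(q), split(k), split(v)

    # Scaled dot-product attention.
    weights = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(head_dim))
    out = weights @ v  # (batch*num_heads, seq_len, head_dim)

    # Merge heads back and apply the output projection.
    out = out.transpose(1, 0, 2).reshape(seq_len, batch, embed_dim)
    out = out @ out_proj_weight.T + out_proj_bias

    # Average the attention weights over heads (need_weights=True behavior).
    weights = weights.reshape(batch, num_heads, seq_len, seq_len).mean(axis=1)
    return out, weights
```

Shapes to keep in mind: the output is `(seq_len, batch, embed_dim)` and the returned weights are `(batch, seq_len, seq_len)`, with each row summing to one.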