multi_head_attention_forward
dragon.vm.torch.nn.functional.multi_head_attention_forward(
    query,
    key,
    value,
    embed_dim_to_check,
    num_heads,
    in_proj_weight,
    in_proj_bias,
    out_proj_weight,
    out_proj_bias,
    dropout_p=0.0,
    training=True,
    need_weights=True,
    key_padding_mask=None,
    attn_mask=None,
    use_separate_proj_weight=False,
    q_proj_weight=None,
    k_proj_weight=None,
    v_proj_weight=None
)

Apply the multi-head attention to input. [Vaswani et al., 2017].
- Parameters:
- query (dragon.vm.torch.Tensor) – The query tensor.
- key (dragon.vm.torch.Tensor) – The key tensor.
- value (dragon.vm.torch.Tensor) – The value tensor.
- embed_dim_to_check (int) – The expected dimension of input embeddings.
- num_heads (int) – The number of parallel heads.
- in_proj_weight (dragon.vm.torch.Tensor) – The weight tensor for input projection.
- in_proj_bias (dragon.vm.torch.Tensor) – The bias tensor for input projection.
- out_proj_weight (dragon.vm.torch.Tensor) – The weight tensor for output projection.
- out_proj_bias (dragon.vm.torch.Tensor) – The bias tensor for output projection.
- dropout_p (float, optional, default=0.) – The probability of setting attention weights to zero.
- training (bool, optional, default=True) – Apply dropout if True.
- need_weights (bool, optional, default=True) – Return the attention weights or not.
- key_padding_mask (dragon.vm.torch.Tensor, optional) – The mask to prevent attention to padded keys.
- attn_mask (dragon.vm.torch.Tensor, optional) – The mask to prevent attention to certain positions.
- use_separate_proj_weight (bool, optional, default=False) – Provide separate projection weights or not.
- q_proj_weight (dragon.vm.torch.Tensor, optional) – The separate weight tensor for query projection.
- k_proj_weight (dragon.vm.torch.Tensor, optional) – The separate weight tensor for key projection.
- v_proj_weight (dragon.vm.torch.Tensor, optional) – The separate weight tensor for value projection.
- Returns:
Tuple[dragon.vm.torch.Tensor, dragon.vm.torch.Tensor] – The output tensor and the attention weights.
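Example

The sketch below shows one way to call this function with packed input-projection parameters. It is a minimal example, assuming dragon.vm.torch exposes the PyTorch-style tensor constructors (randn, zeros) and that tensors follow the sequence-first layout used by torch.nn.functional.multi_head_attention_forward: query is (L, N, E), key and value are (S, N, E), in_proj_weight is (3 * E, E), and out_proj_weight is (E, E). The weight values here are random placeholders, not trained parameters.

from dragon.vm import torch
from dragon.vm.torch.nn import functional as F

L, S, N = 4, 6, 2            # Target length, source length, batch size.
embed_dim, num_heads = 8, 2  # Embedding dimension and number of heads.

query = torch.randn(L, N, embed_dim)
key = torch.randn(S, N, embed_dim)
value = torch.randn(S, N, embed_dim)

# Packed projection parameters for query, key, and value.
in_proj_weight = torch.randn(3 * embed_dim, embed_dim)
in_proj_bias = torch.zeros(3 * embed_dim)
out_proj_weight = torch.randn(embed_dim, embed_dim)
out_proj_bias = torch.zeros(embed_dim)

output, attn_weights = F.multi_head_attention_forward(
    query, key, value,
    embed_dim_to_check=embed_dim,
    num_heads=num_heads,
    in_proj_weight=in_proj_weight,
    in_proj_bias=in_proj_bias,
    out_proj_weight=out_proj_weight,
    out_proj_bias=out_proj_bias,
    dropout_p=0.1,
    training=True,
    need_weights=True,
)
# output has shape (L, N, embed_dim); attn_weights is returned because
# need_weights=True (in the PyTorch convention it is averaged over heads).

Setting use_separate_proj_weight=True and passing q_proj_weight, k_proj_weight, and v_proj_weight instead of in_proj_weight is the alternative when query, key, and value require distinct projection matrices.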