multi_head_attention_forward
- dragon.vm.torch.nn.functional.multi_head_attention_forward(
 query,
 key,
 value,
 embed_dim_to_check,
 num_heads,
 in_proj_weight,
 in_proj_bias,
 out_proj_weight,
 out_proj_bias,
 dropout_p=0.0,
 training=True,
 need_weights=True,
 key_padding_mask=None,
 attn_mask=None,
 use_separate_proj_weight=False,
 q_proj_weight=None,
 k_proj_weight=None,
 v_proj_weight=None
 )
- Apply multi-head attention to the input [Vaswani et al., 2017].
 - Parameters:
- query (dragon.vm.torch.Tensor) – The query tensor.
- key (dragon.vm.torch.Tensor) – The key tensor.
- value (dragon.vm.torch.Tensor) – The value tensor.
- embed_dim_to_check (int) – The dimension of input embeddings.
- num_heads (int) – The number of parallel heads.
- in_proj_weight (dragon.vm.torch.Tensor) – The weight tensor for input projection.
- in_proj_bias (dragon.vm.torch.Tensor) – The bias tensor for input projection.
- out_proj_weight (dragon.vm.torch.Tensor) – The weight tensor for output projection.
- out_proj_bias (dragon.vm.torch.Tensor) – The bias tensor for output projection.
- dropout_p (float, optional, default=0.) – The probability of zeroing an attention weight.
- training (bool, optional, default=True) – Apply dropout if True.
- need_weights (bool, optional, default=True) – Return the attention weights or not.
- key_padding_mask (dragon.vm.torch.Tensor, optional) – The mask that prevents attention to padded keys.
- attn_mask (dragon.vm.torch.Tensor, optional) – The mask that prevents attention to certain positions.
- use_separate_proj_weight (bool, optional, default=False) – Provide separate projection weights or not.
- q_proj_weight (dragon.vm.torch.Tensor, optional) – The separate weight tensor for query projection.
- k_proj_weight (dragon.vm.torch.Tensor, optional) – The separate weight tensor for key projection.
- v_proj_weight (dragon.vm.torch.Tensor, optional) – The separate weight tensor for value projection.
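
The two mask parameters can be pictured with a small NumPy sketch. It is illustrative only and does not call dragon: the helper name `masked_softmax`, the boolean convention (True means "do not attend"), and the shapes `(N, S)` for `key_padding_mask` and `(L, S)` for `attn_mask` are assumptions borrowed from the common PyTorch-style convention, not something this page states.

```python
import numpy as np

def masked_softmax(scores, key_padding_mask=None, attn_mask=None):
    """Hypothetical helper: softmax over keys with masked positions excluded.

    scores: (N, L, S) raw attention scores for a single head.
    Assumes every query keeps at least one unmasked key; a fully
    masked row would produce NaN.
    """
    scores = scores.astype(np.float64)
    if attn_mask is not None:
        # Assumed shape (L, S), True = this query/key pair is blocked.
        scores = np.where(attn_mask[None, :, :], -np.inf, scores)
    if key_padding_mask is not None:
        # Assumed shape (N, S), True = this key is padding.
        scores = np.where(key_padding_mask[:, None, :], -np.inf, scores)
    # Numerically stable softmax; exp(-inf) evaluates to 0.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)

# One batch item, two queries, three keys; the last key is padding.
scores = np.zeros((1, 2, 3))
key_padding_mask = np.array([[False, False, True]])
# Causal-style mask: query i may only attend to keys 0..i.
attn_mask = np.triu(np.ones((2, 3)), k=1).astype(bool)
weights = masked_softmax(scores, key_padding_mask, attn_mask)
```

With these inputs the first query can attend only to the first key, and the second query splits its weight evenly over the first two keys.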
 
 - Returns:
- Tuple[dragon.vm.torch.Tensor, dragon.vm.torch.Tensor] – The output tensor and the attention weights tensor.
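
As a rough reference for what the function computes, the following NumPy sketch walks through the packed-projection path without masks or dropout. It is not dragon's implementation: the name `mha_forward_sketch`, the `(L, N, E)` / `(S, N, E)` input layout, the `(3E, E)` shape of `in_proj_weight` with query, key, and value blocks stacked in that order, and the averaging of returned weights over heads are all assumptions taken from the PyTorch-style convention.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def mha_forward_sketch(query, key, value, num_heads,
                       in_proj_weight, in_proj_bias,
                       out_proj_weight, out_proj_bias):
    """Hypothetical reference: multi-head attention, no masks, no dropout."""
    L, N, E = query.shape
    head_dim = E // num_heads
    assert head_dim * num_heads == E, "embed dim must be divisible by num_heads"

    # Assumed packing: rows [0:E] project q, [E:2E] project k, [2E:3E] project v.
    w_q, w_k, w_v = np.split(in_proj_weight, 3, axis=0)
    b_q, b_k, b_v = np.split(in_proj_bias, 3, axis=0)
    q = query @ w_q.T + b_q   # (L, N, E)
    k = key @ w_k.T + b_k     # (S, N, E)
    v = value @ w_v.T + b_v   # (S, N, E)

    def split_heads(x):
        # (T, N, E) -> (N, H, T, D)
        T = x.shape[0]
        return x.reshape(T, N, num_heads, head_dim).transpose(1, 2, 0, 3)

    q, k, v = split_heads(q), split_heads(k), split_heads(v)

    # Scaled dot-product attention per head.
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_dim)  # (N, H, L, S)
    weights = softmax(scores, axis=-1)
    context = weights @ v                                     # (N, H, L, D)

    # Concatenate heads and apply the output projection.
    context = context.transpose(2, 0, 1, 3).reshape(L, N, E)
    output = context @ out_proj_weight.T + out_proj_bias      # (L, N, E)
    return output, weights.mean(axis=1)                       # weights: (N, L, S)

rng = np.random.default_rng(0)
L, S, N, E, H = 4, 6, 2, 8, 2
query = rng.standard_normal((L, N, E))
key = rng.standard_normal((S, N, E))
value = rng.standard_normal((S, N, E))
output, attn = mha_forward_sketch(
    query, key, value, H,
    in_proj_weight=rng.standard_normal((3 * E, E)),
    in_proj_bias=np.zeros(3 * E),
    out_proj_weight=rng.standard_normal((E, E)),
    out_proj_bias=np.zeros(E),
)
print(output.shape, attn.shape)
```

When `use_separate_proj_weight` is set, the three blocks obtained from `np.split` above would presumably be replaced by `q_proj_weight`, `k_proj_weight`, and `v_proj_weight`; the rest of the computation is unchanged under that assumption.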
 - See also 
