megengine.functional.nn.multi_head_attention

multi_head_attention(query, key, value, embed_dim, num_heads, attn_drop, out_drop, io_weight_bias, qproj_size=None, kproj_size=None, vproj_size=None, oproj_size=None, qbias=False, kbias=False, vbias=False, obias=False, bias_k=None, bias_v=None, add_zero_attn=False, key_padding_mask=None, attn_mask=None, need_weights=False, average_attn_weights=False, is_causal=False, maybe_cudnn_style_mask=False, reslink=False, training=True)[source]

Allows the model to jointly attend to information from different representation subspaces. See Attention Is All You Need.

\[\text{MultiHeadAttn}\big(q, k, v, W_Q, W_K, W_V, W_O\big) = \sum^{nHeads-1}_{i=0}W_{O,i}h_i\]

where \(h_i = W_{V,i} v \,\text{Softmax}\Big( \text{smScaler} \cdot k^T W^T_{K,i} W_{Q,i} q \Big), \text{ for } i = 0 \dots nHeads-1\).

See MultiHeadAttn for more details.

Note: This API is experimental, and there is a possibility of subsequent changes.
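The formula above can be sketched in plain NumPy. This is an illustrative reimplementation of the math, not the MegEngine kernel: it assumes a single query vector \(q\) of dimension \(E\), keys and values stored column-wise as \((E, S)\) matrices, per-head projections of size \(d = E / nHeads\), and \(\text{smScaler} = 1/\sqrt{d}\).

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

rng = np.random.default_rng(0)
E, S, n_heads = 8, 5, 2
d = E // n_heads                        # per-head dimension

q = rng.standard_normal(E)              # single query vector
k = rng.standard_normal((E, S))         # S key vectors as columns
v = rng.standard_normal((E, S))         # S value vectors as columns

# One projection matrix per head (shapes follow the formula, not a framework API).
W_Q = rng.standard_normal((n_heads, d, E))
W_K = rng.standard_normal((n_heads, d, E))
W_V = rng.standard_normal((n_heads, d, E))
W_O = rng.standard_normal((n_heads, E, d))

sm_scaler = 1.0 / np.sqrt(d)

out = np.zeros(E)
for i in range(n_heads):
    # attention scores: smScaler * k^T W_{K,i}^T W_{Q,i} q  -> shape (S,)
    scores = sm_scaler * (k.T @ W_K[i].T @ (W_Q[i] @ q))
    # h_i = W_{V,i} v Softmax(scores)  -> shape (d,)
    h_i = (W_V[i] @ v) @ softmax(scores)
    # accumulate the sum over heads: W_{O,i} h_i  -> shape (E,)
    out += W_O[i] @ h_i
```

Each head attends over all \(S\) positions independently, and the output projections map the per-head results back to the model dimension before summation.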

Parameters
  • query (Tensor) – the query tensor. Attention maps a query and a set of key-value pairs to an output; see “Attention Is All You Need” for more details.

  • key (Tensor) – the key tensor. See “Attention Is All You Need” for more details.

  • value (Tensor) – the value tensor. See “Attention Is All You Need” for more details.

  • embed_dim (int) – total dimension of the model.

  • num_heads (int) – parallel attention heads.

  • attn_drop (float) – probability of an element of the attention matrix to be zeroed.

  • out_drop (float) – probability of an element of the final output to be zeroed.

  • io_weight_bias (Optional[Tensor]) – input/output projection weights and biases, all in one tensor. The order of arrangement is: query weight, key weight, value weight, out weight, query bias, key bias, value bias, out bias. The following parameters indicate whether each item exists: qproj_size, kproj_size, vproj_size, oproj_size, qbias, kbias, vbias, obias. Note: \(Y=X@W+B\) is used here instead of \(Y=X@W^T+B\) as in PyTorch.

  • qproj_size (Optional[int]) – the projection size of the query weight in io_weight_bias; 0 disables the query projection, and no query projection weight is stored.

  • kproj_size (Optional[int]) – the projection size of the key weight in io_weight_bias; 0 disables the key projection, and no key projection weight is stored.

  • vproj_size (Optional[int]) – the projection size of the value weight in io_weight_bias; 0 disables the value projection, and no value projection weight is stored.

  • oproj_size (Optional[int]) – the projection size of the out weight in io_weight_bias; 0 disables the output projection, and no output projection weight is stored.

  • qbias (bool) – indicates whether there is a query bias in io_weight_bias; only valid when qproj_size > 0.

  • kbias (bool) – indicates whether there is a key bias in io_weight_bias; only valid when kproj_size > 0.

  • vbias (bool) – indicates whether there is a value bias in io_weight_bias; only valid when vproj_size > 0.

  • obias (bool) – indicates whether there is an output bias in io_weight_bias; only valid when oproj_size > 0.

  • bias_k (Optional[Tensor]) – the bias to be added to the key sequence at the sequence dimension. Distinguished from kbias: bias_k is not a linear-layer bias; it is added to \(K\) at the sequence dimension, where \(K\) is the key matrix after projection and is used to compute the attention matrix. Note: must currently be set to None; configuring this parameter is not yet supported. Only a cudnn implementation exists at present, and this restriction may be loosened once an MHA proxy implementation is added.

  • bias_v (Optional[Tensor]) – the bias to be added to the value sequence at the sequence dimension. Distinguished from vbias: bias_v is not a linear-layer bias; it is added to \(V\) at the sequence dimension, where \(V\) is the value matrix after projection and is used to compute the attention output. Note: must currently be set to None; configuring this parameter is not yet supported. Only a cudnn implementation exists at present, and this restriction may be loosened once an MHA proxy implementation is added.

  • add_zero_attn (bool) – if specified, adds a new batch of zeros to the key and value sequences at the sequence dimension. Default: False. Note: must currently be set to False; configuring this parameter is not yet supported. Only a cudnn implementation exists at present, and this restriction may be loosened once an MHA proxy implementation is added.

  • key_padding_mask (Optional[Tensor]) – if specified, a mask of shape \((N, S)\) indicating which elements within key to ignore for the purpose of attention (i.e. treat as “padding”). For unbatched query, shape should be \((S)\). Binary and float masks are supported. For a binary mask, a True value indicates that the corresponding key value will be ignored for the purpose of attention. For a float mask, it will be directly added to the corresponding key value.

  • attn_mask (Optional[Tensor]) – 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all the batches while a 3D mask allows to specify a different mask for the entries of each batch.

  • need_weights (bool) – indicates whether to return the attention weight, which is the output result of softmax. Default: False

  • average_attn_weights (bool) – if true, the returned attn_weights are averaged across heads; otherwise attn_weights are provided separately per head. This flag only has an effect when need_weights=True. Default: False (i.e. weights are provided per head).

  • is_causal (bool) – if specified, applies a causal mask as the attention mask. Default: False. Warning: is_causal provides a hint that attn_mask is the causal mask. Providing an incorrect hint can result in incorrect execution of both the forward and backward passes.

  • maybe_cudnn_style_mask (bool) – if specified, applies a cudnn-style mask as the attention mask. Default: False. Note: in the cudnn style, the shape of attn_mask is \((2, L)\) and the shape of key_padding_mask is \((2, N)\). Warning: like is_causal, maybe_cudnn_style_mask provides a hint that attn_mask and key_padding_mask are cudnn-style masks. Providing an incorrect hint can result in incorrect execution of both the forward and backward passes. In addition, if the _merge_masks function returns merge_type=cudnn_style_mask, please ensure that the other conditions required by the cudnn implementation are met, otherwise an error will be reported.

  • reslink (bool) – add the input query to the final output. Note: only valid when the shape of the input query matches the shape of the output. Must currently be set to False; configuring this parameter is not yet supported. Only a cudnn implementation exists at present, and this restriction may be loosened once an MHA proxy implementation is added.

  • training (bool) – dropout is applied only if this is True.
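The io_weight_bias packing described above can be sketched in NumPy. This is an illustrative layout under the stated assumptions (all four projections enabled with equal sizes, all biases enabled, weights stored for the \(Y=X@W+B\) convention, i.e. shape (input_dim, output_dim) rather than PyTorch's transposed layout); the variable names are hypothetical, not MegEngine API:

```python
import numpy as np

rng = np.random.default_rng(0)
embed_dim = 8
proj = 8          # qproj_size = kproj_size = vproj_size = oproj_size = proj

# Y = X @ W + B convention: each weight is (input_dim, output_dim),
# not transposed as in PyTorch's Y = X @ W^T + B.
wq = rng.standard_normal((embed_dim, proj))
wk = rng.standard_normal((embed_dim, proj))
wv = rng.standard_normal((embed_dim, proj))
wo = rng.standard_normal((proj, embed_dim))
bq = np.zeros(proj)
bk = np.zeros(proj)
bv = np.zeros(proj)
bo = np.zeros(embed_dim)

# Documented packing order: query/key/value/out weights first,
# then query/key/value/out biases, flattened into one 1-D tensor.
io_weight_bias = np.concatenate(
    [t.ravel() for t in (wq, wk, wv, wo, bq, bk, bv, bo)]
)
```

If a projection were disabled (its proj_size set to 0) or a bias flag were False, the corresponding segment would simply be omitted from the concatenation.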

Outputs:
  • out[0]=attn_output - Attention outputs of shape \((N, L, E)\), where \(L\) is the target sequence length, \(N\) is the batch size, and \(E\) is the embedding dimension embed_dim.

  • out[1]=attn_output_weights - Only returned when need_weights=True. If average_attn_weights=True, returns attention weights averaged across heads, of shape \((L, S)\) when the input is unbatched or \((N, L, S)\) otherwise, where \(N\) is the batch size, \(L\) is the target sequence length, and \(S\) is the source sequence length. If average_attn_weights=False, returns attention weights per head, of shape \((\text{num\_heads}, L, S)\) when the input is unbatched or \((N * \text{num\_heads}, L, S)\) otherwise.

  • out[2]=mask_reversespace - Used to save the dropout mask needed for backward propagation.

  • out[3]=othr_reversespace - Used to save the intermediate results needed for backward propagation.
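The key_padding_mask semantics described in the parameter list can be sketched in NumPy: for a binary mask, a True entry marks a key position to ignore, which is equivalent to adding \(-\infty\) to its score before the softmax, so that position receives zero attention weight. This is an illustration of the masking math, not MegEngine code:

```python
import numpy as np

# Raw attention scores for one query position: shape (N, S) with N=2, S=4.
scores = np.zeros((2, 4))

# Binary key_padding_mask of shape (N, S): True marks padded keys to ignore.
key_padding_mask = np.array([[False, False, True,  True],
                             [False, False, False, True]])

# Binary masking is equivalent to adding -inf at masked positions.
# (A float mask would instead be added to the scores directly.)
masked = np.where(key_padding_mask, -np.inf, scores)

# Softmax over the key axis: masked positions get exactly zero weight.
e = np.exp(masked - masked.max(axis=-1, keepdims=True))
weights = e / e.sum(axis=-1, keepdims=True)
```

The remaining (unmasked) positions renormalize to sum to 1 in each row, so padding never influences the attention output.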