MultiheadAttention¶
class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)[source]¶

Allows the model to jointly attend to information from different representation subspaces. See Attention Is All You Need.

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W^O$$

where $\text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V)$.
- Parameters
embed_dim – total dimension of the model.
num_heads – parallel attention heads.
dropout – dropout probability applied to attn_output_weights. Default: 0.0 (no dropout).
bias – add bias as module parameter. Default: True.
add_bias_kv – add bias to the key and value sequences at dim=0.
add_zero_attn – add a new batch of zeros to the key and value sequences at dim=1.
kdim – total number of features in key. Default: None.
vdim – total number of features in value. Default: None.
batch_first – If True, then the input and output tensors are provided as (batch, seq, feature). Default: False (seq, batch, feature).
Note that if kdim and vdim are None, they will be set to embed_dim such that query, key, and value have the same number of features.

Examples:
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
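As a further, hedged illustration (not taken from the official docs), the sketch below exercises the batch_first, kdim, and vdim options; the concrete sizes (embed_dim=16, num_heads=4, batch size 2, target length 5, source length 7, kdim=8, vdim=12) are arbitrary values chosen for the example.

>>> import torch
>>> import torch.nn as nn
>>> # batch_first=True: inputs and outputs are (batch, seq, feature).
>>> mha = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
>>> query = torch.randn(2, 5, 16)   # (batch, tgt_len, embed_dim)
>>> key = torch.randn(2, 7, 16)     # (batch, src_len, embed_dim)
>>> value = torch.randn(2, 7, 16)
>>> attn_output, attn_output_weights = mha(query, key, value)
>>> attn_output.shape, attn_output_weights.shape
(torch.Size([2, 5, 16]), torch.Size([2, 5, 7]))
>>> # kdim/vdim: key and value may carry feature sizes different from embed_dim.
>>> mha_kv = nn.MultiheadAttention(embed_dim=16, num_heads=4, kdim=8, vdim=12, batch_first=True)
>>> key = torch.randn(2, 7, 8)      # (batch, src_len, kdim)
>>> value = torch.randn(2, 7, 12)   # (batch, src_len, vdim)
>>> attn_output, _ = mha_kv(query, key, value)
>>> attn_output.shape
torch.Size([2, 5, 16])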
forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None)[source]¶
- Parameters
query, key, value – map a query and a set of key-value pairs to an output. See “Attention Is All You Need” for more details.
key_padding_mask – if provided, specified padding elements in the key will be ignored by the attention. When given a binary mask, positions with the value True will be ignored; when given a byte mask, non-zero positions will be ignored.
need_weights – if True, also return attn_output_weights. Default: True.
attn_mask – 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcast across the batch, while a 3D mask allows specifying a different mask for each entry in the batch (see the sketch below).
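As a hedged, illustrative sketch (not part of the official docs): a float attn_mask is additive, so 0.0 leaves a position unchanged and -inf disables it. The sizes below (embed_dim=8, num_heads=2, L=S=3, N=1) are arbitrary.

>>> import torch
>>> import torch.nn as nn
>>> mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)
>>> query = torch.randn(3, 1, 8)   # (L, N, E) with the default layout
>>> key = torch.randn(3, 1, 8)     # (S, N, E)
>>> value = torch.randn(3, 1, 8)
>>> # Additive float mask: -inf above the diagonal blocks attention to later positions.
>>> blocked = torch.triu(torch.ones(3, 3), diagonal=1).bool()
>>> float_mask = torch.zeros(3, 3).masked_fill(blocked, float("-inf"))
>>> attn_output, attn_output_weights = mha(query, key, value, attn_mask=float_mask)
>>> attn_output_weights.shape      # (N, L, S)
torch.Size([1, 3, 3])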
- Shapes for inputs:
query: (L, N, E) where L is the target sequence length, N is the batch size, E is the embedding dimension. (N, L, E) if batch_first is True.
key: (S, N, E) where S is the source sequence length, N is the batch size, E is the embedding dimension. (N, S, E) if batch_first is True.
value: (S, N, E) where S is the source sequence length, N is the batch size, E is the embedding dimension. (N, S, E) if batch_first is True.
key_padding_mask: (N, S) where N is the batch size, S is the source sequence length. If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions will be unchanged. If a BoolTensor is provided, the positions with the value of True will be ignored while the positions with the value of False will be unchanged.
attn_mask: if a 2D mask: (L, S) where L is the target sequence length, S is the source sequence length. If a 3D mask: (N * num_heads, L, S) where N is the batch size, L is the target sequence length, S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend while the zero positions will be unchanged. If a BoolTensor is provided, positions with True are not allowed to attend while False values will be unchanged. If a FloatTensor is provided, it will be added to the attention weight.
- Shapes for outputs:
attn_output: (L, N, E) where L is the target sequence length, N is the batch size, E is the embedding dimension. (N, L, E) if batch_first is True.
attn_output_weights: (N, L, S) where N is the batch size, L is the target sequence length, S is the source sequence length.
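A final hedged sketch (not from the official docs) ties the input and output shapes together with boolean masks and the default (seq, batch, feature) layout; the sizes L=3, S=4, N=2, E=8 and the mask contents are arbitrary illustration values.

>>> import torch
>>> import torch.nn as nn
>>> L, S, N, E = 3, 4, 2, 8
>>> mha = nn.MultiheadAttention(embed_dim=E, num_heads=2)
>>> query = torch.randn(L, N, E)
>>> key = torch.randn(S, N, E)
>>> value = torch.randn(S, N, E)
>>> # (N, S) BoolTensor: True marks padded key positions to ignore.
>>> key_padding_mask = torch.tensor([[False, False, False, True],
...                                  [False, False, True,  True]])
>>> # (L, S) BoolTensor: True marks positions that are not allowed to attend.
>>> attn_mask = torch.triu(torch.ones(L, S), diagonal=1).bool()
>>> attn_output, attn_output_weights = mha(query, key, value,
...                                        key_padding_mask=key_padding_mask,
...                                        attn_mask=attn_mask)
>>> attn_output.shape, attn_output_weights.shape
(torch.Size([3, 2, 8]), torch.Size([2, 3, 4]))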