MultiheadAttention

class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)[source]

Allows the model to jointly attend to information from different representation subspaces. See "Attention Is All You Need" for more details.

\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W^O

where \text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V).
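
For reference, a minimal sketch of this formula in plain tensor operations, assuming scaled dot-product attention for Attention(·) and an equal per-head dimension of embed_dim // num_heads; the weight names below are purely illustrative and do not correspond to the module's internal parameters:

import math
import torch

embed_dim, num_heads = 8, 2
head_dim = embed_dim // num_heads          # each head works in a smaller subspace
L, S = 3, 5                                # target / source sequence lengths
Q, K, V = torch.randn(L, embed_dim), torch.randn(S, embed_dim), torch.randn(S, embed_dim)

# One projection per head (illustrative random weights), plus the output projection W^O
W_q = [torch.randn(embed_dim, head_dim) for _ in range(num_heads)]
W_k = [torch.randn(embed_dim, head_dim) for _ in range(num_heads)]
W_v = [torch.randn(embed_dim, head_dim) for _ in range(num_heads)]
W_o = torch.randn(num_heads * head_dim, embed_dim)

heads = []
for i in range(num_heads):
    q, k, v = Q @ W_q[i], K @ W_k[i], V @ W_v[i]
    attn = torch.softmax(q @ k.T / math.sqrt(head_dim), dim=-1)  # Attention(QW_i^Q, KW_i^K, VW_i^V)
    heads.append(attn @ v)                                       # head_i
out = torch.cat(heads, dim=-1) @ W_o                             # Concat(head_1, ..., head_h) W^O
print(out.shape)   # torch.Size([3, 8])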

Parameters
  • embed_dim – total dimension of the model.

  • num_heads – number of parallel attention heads. embed_dim is split across the heads, so each head has dimension embed_dim // num_heads.

  • dropout – dropout probability applied to attn_output_weights. Default: 0.0 (no dropout).

  • bias – add bias as module parameter. Default: True.

  • add_bias_kv – add bias to the key and value sequences at dim=0.

  • add_zero_attn – add a new batch of zeros to the key and value sequences at dim=1.

  • kdim – total number of features in key. Default: None.

  • vdim – total number of features in value. Default: None.

  • batch_first – If True, then the input and output tensors are provided as (batch, seq, feature). Default: False (seq, batch, feature).

Note that if kdim and vdim are None, they will be set to embed_dim such that query, key, and value have the same number of features.

Examples:

>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
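
A fuller sketch with concrete sizes, showing the default (seq, batch, feature) layout, batch_first=True, and separate key/value feature sizes via kdim/vdim; the numbers are arbitrary and chosen only for illustration:

import torch
import torch.nn as nn

embed_dim, num_heads = 16, 4
L, S, N = 6, 10, 2   # target length, source length, batch size

# Default layout: (seq, batch, feature)
mha = nn.MultiheadAttention(embed_dim, num_heads)
query = torch.randn(L, N, embed_dim)
key = torch.randn(S, N, embed_dim)
value = torch.randn(S, N, embed_dim)
attn_output, attn_output_weights = mha(query, key, value)
print(attn_output.shape, attn_output_weights.shape)   # (L, N, E), (N, L, S)

# batch_first=True: inputs and outputs are (batch, seq, feature)
mha_bf = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
attn_output, _ = mha_bf(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1))
print(attn_output.shape)                               # (N, L, E)

# kdim/vdim: key and value may carry a different number of features than the query
mha_kv = nn.MultiheadAttention(embed_dim, num_heads, kdim=32, vdim=24)
attn_output, _ = mha_kv(query, torch.randn(S, N, 32), torch.randn(S, N, 24))
print(attn_output.shape)                               # (L, N, E)
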
forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None)[source]
Parameters
  • query, key, value – map a query and a set of key-value pairs to an output. See "Attention Is All You Need" for more details.

  • key_padding_mask – if provided, specified padding elements in the key will be ignored by the attention. When given a binary mask, positions where the value is True will be ignored; when given a byte mask, positions where the value is non-zero will be ignored.

  • need_weights – if True, additionally return attn_output_weights. Default: True.

  • attn_mask – 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcast across the batch, while a 3D mask allows a different mask to be specified for each entry in the batch.

Shapes for inputs:
  • query: (L, N, E) where L is the target sequence length, N is the batch size, E is the embedding dimension. (N, L, E) if batch_first is True.

  • key: (S, N, E), where S is the source sequence length, N is the batch size, E is the embedding dimension. (N, S, E) if batch_first is True.

  • value: (S, N, E) where S is the source sequence length, N is the batch size, E is the embedding dimension. (N, S, E) if batch_first is True.

  • key_padding_mask: (N, S) where N is the batch size, S is the source sequence length. If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions will be unchanged. If a BoolTensor is provided, positions with the value of True will be ignored while positions with the value of False will be unchanged.

  • attn_mask: if a 2D mask: (L, S) where L is the target sequence length, S is the source sequence length.

    If a 3D mask: (N·num_heads, L, S) where N is the batch size, L is the target sequence length, S is the source sequence length. attn_mask ensures that position i attends only to the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend while the zero positions will be unchanged. If a BoolTensor is provided, positions with True are not allowed to attend while positions with False will be unchanged. If a FloatTensor is provided, it will be added to the attention weight. (A sketch of constructing these masks follows this list.)
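
A minimal sketch of building the two mask arguments with the shapes and boolean semantics described above (True means "ignore"/"not allowed to attend"); all sizes are arbitrary and chosen only for illustration:

import torch
import torch.nn as nn

embed_dim, num_heads = 16, 4
L = S = 5   # target and source lengths (self-attention style)
N = 2       # batch size

mha = nn.MultiheadAttention(embed_dim, num_heads)
query = torch.randn(L, N, embed_dim)
key = value = torch.randn(S, N, embed_dim)

# key_padding_mask: (N, S) BoolTensor, True marks padded source positions to ignore
key_padding_mask = torch.zeros(N, S, dtype=torch.bool)
key_padding_mask[:, -1] = True   # pretend the last key position is padding

# attn_mask: 2D (L, S) BoolTensor, True blocks attention to that position (causal mask)
attn_mask = torch.triu(torch.ones(L, S, dtype=torch.bool), diagonal=1)

attn_output, attn_output_weights = mha(
    query, key, value,
    key_padding_mask=key_padding_mask,
    attn_mask=attn_mask,
)
print(attn_output.shape, attn_output_weights.shape)  # (L, N, E), (N, L, S)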

Shapes for outputs:
  • attn_output: (L, N, E) where L is the target sequence length, N is the batch size, E is the embedding dimension. (N, L, E) if batch_first is True.

  • attn_output_weights: (N, L, S) where N is the batch size, L is the target sequence length, S is the source sequence length. (A shape-checking sketch follows below.)
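
A sketch of the 3D (N·num_heads, L, S) float-mask form and a check of the output shapes above, plus need_weights=False; all sizes are arbitrary:

import torch
import torch.nn as nn

embed_dim, num_heads = 16, 4
L, S, N = 3, 5, 2

mha = nn.MultiheadAttention(embed_dim, num_heads)
query = torch.randn(L, N, embed_dim)
key = value = torch.randn(S, N, embed_dim)

# 3D attn_mask: one (L, S) mask per head and per batch entry, stacked to
# (N * num_heads, L, S); a float mask is added to the attention weights.
attn_mask = torch.zeros(N * num_heads, L, S)
attn_mask[..., -1] = float("-inf")       # additively mask out the last source position

attn_output, attn_output_weights = mha(query, key, value, attn_mask=attn_mask)
print(attn_output.shape)          # torch.Size([3, 2, 16])  -> (L, N, E)
print(attn_output_weights.shape)  # torch.Size([2, 3, 5])   -> (N, L, S)

# With need_weights=False, only attn_output is computed; the second return value is None.
attn_output, weights = mha(query, key, value, need_weights=False)
print(weights)                    # None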
