Add head_mask and decoder_head_mask to PyTorch LED (#9856)
* Add {decoder_,}head_mask to LED
* Fix create_custom_forward signature in encoder
* Add head_mask to longformer
* Add head_mask to longformer to fix dependencies
of LED on Longformer.
* Not working yet
* Add missing one input in longformer_modeling.py
* make fix-copies
This commit is contained in:
@@ -164,6 +164,7 @@ class LEDEncoderSelfAttention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
|
layer_head_mask=None,
|
||||||
is_index_masked=None,
|
is_index_masked=None,
|
||||||
is_index_global_attn=None,
|
is_index_global_attn=None,
|
||||||
is_global_attn=None,
|
is_global_attn=None,
|
||||||
@@ -251,6 +252,12 @@ class LEDEncoderSelfAttention(nn.Module):
|
|||||||
|
|
||||||
attn_probs = F.softmax(attn_scores, dim=-1, dtype=torch.float32) # use fp32 for numerical stability
|
attn_probs = F.softmax(attn_scores, dim=-1, dtype=torch.float32) # use fp32 for numerical stability
|
||||||
|
|
||||||
|
if layer_head_mask is not None:
|
||||||
|
assert layer_head_mask.size() == (
|
||||||
|
self.num_heads,
|
||||||
|
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
||||||
|
attn_probs = layer_head_mask.view(1, 1, -1, 1) * attn_probs
|
||||||
|
|
||||||
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
|
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
|
||||||
attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
|
attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
|
||||||
attn_probs = attn_probs.type_as(attn_scores)
|
attn_probs = attn_probs.type_as(attn_scores)
|
||||||
@@ -288,6 +295,7 @@ class LEDEncoderSelfAttention(nn.Module):
|
|||||||
global_attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden(
|
global_attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
max_num_global_attn_indices=max_num_global_attn_indices,
|
max_num_global_attn_indices=max_num_global_attn_indices,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
|
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
|
||||||
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
|
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
|
||||||
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
|
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
|
||||||
@@ -595,6 +603,7 @@ class LEDEncoderSelfAttention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
max_num_global_attn_indices,
|
max_num_global_attn_indices,
|
||||||
|
layer_head_mask,
|
||||||
is_local_index_global_attn_nonzero,
|
is_local_index_global_attn_nonzero,
|
||||||
is_index_global_attn_nonzero,
|
is_index_global_attn_nonzero,
|
||||||
is_local_index_no_global_attn_nonzero,
|
is_local_index_no_global_attn_nonzero,
|
||||||
@@ -656,6 +665,18 @@ class LEDEncoderSelfAttention(nn.Module):
|
|||||||
global_attn_scores, dim=-1, dtype=torch.float32
|
global_attn_scores, dim=-1, dtype=torch.float32
|
||||||
) # use fp32 for numerical stability
|
) # use fp32 for numerical stability
|
||||||
|
|
||||||
|
# apply layer head masking
|
||||||
|
if layer_head_mask is not None:
|
||||||
|
assert layer_head_mask.size() == (
|
||||||
|
self.num_heads,
|
||||||
|
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
||||||
|
global_attn_probs_float = layer_head_mask.view(1, -1, 1, 1) * global_attn_probs_float.view(
|
||||||
|
batch_size, self.num_heads, max_num_global_attn_indices, seq_len
|
||||||
|
)
|
||||||
|
global_attn_probs_float = global_attn_probs_float.view(
|
||||||
|
batch_size * self.num_heads, max_num_global_attn_indices, seq_len
|
||||||
|
)
|
||||||
|
|
||||||
global_attn_probs = F.dropout(
|
global_attn_probs = F.dropout(
|
||||||
global_attn_probs_float.type_as(global_attn_scores), p=self.dropout, training=self.training
|
global_attn_probs_float.type_as(global_attn_scores), p=self.dropout, training=self.training
|
||||||
)
|
)
|
||||||
@@ -686,6 +707,7 @@ class LEDEncoderAttention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
is_index_masked: Optional[torch.Tensor] = None,
|
is_index_masked: Optional[torch.Tensor] = None,
|
||||||
is_index_global_attn: Optional[torch.Tensor] = None,
|
is_index_global_attn: Optional[torch.Tensor] = None,
|
||||||
is_global_attn: Optional[bool] = None,
|
is_global_attn: Optional[bool] = None,
|
||||||
@@ -696,6 +718,7 @@ class LEDEncoderAttention(nn.Module):
|
|||||||
self_outputs = self.longformer_self_attn(
|
self_outputs = self.longformer_self_attn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
is_index_masked=is_index_masked,
|
is_index_masked=is_index_masked,
|
||||||
is_index_global_attn=is_index_global_attn,
|
is_index_global_attn=is_index_global_attn,
|
||||||
is_global_attn=is_global_attn,
|
is_global_attn=is_global_attn,
|
||||||
@@ -744,6 +767,7 @@ class LEDDecoderAttention(nn.Module):
|
|||||||
key_value_states: Optional[torch.Tensor] = None,
|
key_value_states: Optional[torch.Tensor] = None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
"""Input shape: Batch x Time x Channel"""
|
"""Input shape: Batch x Time x Channel"""
|
||||||
@@ -810,6 +834,12 @@ class LEDDecoderAttention(nn.Module):
|
|||||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
attn_weights = F.softmax(attn_weights, dim=-1)
|
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||||
|
if layer_head_mask is not None:
|
||||||
|
assert layer_head_mask.size() == (
|
||||||
|
self.num_heads,
|
||||||
|
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
||||||
|
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||||
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
# this operation is a bit awkward, but it's required to
|
# this operation is a bit awkward, but it's required to
|
||||||
@@ -859,6 +889,7 @@ class LEDEncoderLayer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
|
layer_head_mask: torch.Tensor,
|
||||||
is_index_masked=None,
|
is_index_masked=None,
|
||||||
is_index_global_attn=None,
|
is_index_global_attn=None,
|
||||||
is_global_attn=None,
|
is_global_attn=None,
|
||||||
@@ -869,11 +900,14 @@ class LEDEncoderLayer(nn.Module):
|
|||||||
hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||||
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
|
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
|
||||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||||
|
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||||
|
`(config.encoder_attention_heads,)`.
|
||||||
"""
|
"""
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
attn_outputs = self.self_attn(
|
attn_outputs = self.self_attn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
is_index_masked=is_index_masked,
|
is_index_masked=is_index_masked,
|
||||||
is_index_global_attn=is_index_global_attn,
|
is_index_global_attn=is_index_global_attn,
|
||||||
is_global_attn=is_global_attn,
|
is_global_attn=is_global_attn,
|
||||||
@@ -931,6 +965,8 @@ class LEDDecoderLayer(nn.Module):
|
|||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
|
encoder_layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
use_cache: Optional[bool] = True,
|
use_cache: Optional[bool] = True,
|
||||||
@@ -943,6 +979,10 @@ class LEDDecoderLayer(nn.Module):
|
|||||||
encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||||
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
|
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
|
||||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||||
|
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||||
|
`(config.decoder_attention_heads,)`.
|
||||||
|
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of
|
||||||
|
size `(config.encoder_attention_heads,)`.
|
||||||
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
||||||
output_attentions (:obj:`bool`): Whether the base model outputs attentions.
|
output_attentions (:obj:`bool`): Whether the base model outputs attentions.
|
||||||
This requires the attentions tensor to be reshaped in this function.
|
This requires the attentions tensor to be reshaped in this function.
|
||||||
@@ -957,6 +997,7 @@ class LEDDecoderLayer(nn.Module):
|
|||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
past_key_value=self_attn_past_key_value,
|
past_key_value=self_attn_past_key_value,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||||
@@ -975,6 +1016,7 @@ class LEDDecoderLayer(nn.Module):
|
|||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
key_value_states=encoder_hidden_states,
|
key_value_states=encoder_hidden_states,
|
||||||
attention_mask=encoder_attention_mask,
|
attention_mask=encoder_attention_mask,
|
||||||
|
layer_head_mask=encoder_layer_head_mask,
|
||||||
past_key_value=cross_attn_past_key_value,
|
past_key_value=cross_attn_past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
@@ -1155,6 +1197,17 @@ class LEDSeq2SeqModelOutput(ModelOutput):
|
|||||||
Global attentions weights after the attention softmax, used to compute the weighted average in the
|
Global attentions weights after the attention softmax, used to compute the weighted average in the
|
||||||
self-attention heads. Those are the attention weights from every token with global attention to every token
|
self-attention heads. Those are the attention weights from every token with global attention to every token
|
||||||
in the sequence.
|
in the sequence.
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
|
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
last_hidden_state: torch.FloatTensor = None
|
last_hidden_state: torch.FloatTensor = None
|
||||||
@@ -1166,6 +1219,8 @@ class LEDSeq2SeqModelOutput(ModelOutput):
|
|||||||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
head_mask: Optional[torch.FloatTensor] = None
|
||||||
|
decoder_head_mask: Optional[torch.FloatTensor] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -1221,6 +1276,17 @@ class LEDSeq2SeqLMOutput(ModelOutput):
|
|||||||
Global attentions weights after the attention softmax, used to compute the weighted average in the
|
Global attentions weights after the attention softmax, used to compute the weighted average in the
|
||||||
self-attention heads. Those are the attention weights from every token with global attention to every token
|
self-attention heads. Those are the attention weights from every token with global attention to every token
|
||||||
in the sequence.
|
in the sequence.
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
|
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
loss: Optional[torch.FloatTensor] = None
|
loss: Optional[torch.FloatTensor] = None
|
||||||
@@ -1233,6 +1299,8 @@ class LEDSeq2SeqLMOutput(ModelOutput):
|
|||||||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
head_mask: Optional[torch.FloatTensor] = None
|
||||||
|
decoder_head_mask: Optional[torch.FloatTensor] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -1288,6 +1356,17 @@ class LEDSeq2SeqSequenceClassifierOutput(ModelOutput):
|
|||||||
Global attentions weights after the attention softmax, used to compute the weighted average in the
|
Global attentions weights after the attention softmax, used to compute the weighted average in the
|
||||||
self-attention heads. Those are the attention weights from every token with global attention to every token
|
self-attention heads. Those are the attention weights from every token with global attention to every token
|
||||||
in the sequence.
|
in the sequence.
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
|
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
loss: Optional[torch.FloatTensor] = None
|
loss: Optional[torch.FloatTensor] = None
|
||||||
@@ -1300,6 +1379,8 @@ class LEDSeq2SeqSequenceClassifierOutput(ModelOutput):
|
|||||||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
head_mask: Optional[torch.FloatTensor] = None
|
||||||
|
decoder_head_mask: Optional[torch.FloatTensor] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -1357,6 +1438,17 @@ class LEDSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
|
|||||||
Global attentions weights after the attention softmax, used to compute the weighted average in the
|
Global attentions weights after the attention softmax, used to compute the weighted average in the
|
||||||
self-attention heads. Those are the attention weights from every token with global attention to every token
|
self-attention heads. Those are the attention weights from every token with global attention to every token
|
||||||
in the sequence.
|
in the sequence.
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
|
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
loss: Optional[torch.FloatTensor] = None
|
loss: Optional[torch.FloatTensor] = None
|
||||||
@@ -1370,6 +1462,8 @@ class LEDSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
|
|||||||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
head_mask: Optional[torch.FloatTensor] = None
|
||||||
|
decoder_head_mask: Optional[torch.FloatTensor] = None
|
||||||
|
|
||||||
|
|
||||||
LED_START_DOCSTRING = r"""
|
LED_START_DOCSTRING = r"""
|
||||||
@@ -1442,6 +1536,17 @@ LED_INPUTS_DOCSTRING = r"""
|
|||||||
|
|
||||||
- 0 for local attention (a sliding window attention),
|
- 0 for local attention (a sliding window attention),
|
||||||
- 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
|
- 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
|
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
|
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
|
||||||
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
||||||
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
|
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
|
||||||
@@ -1582,6 +1687,7 @@ class LEDEncoder(LEDPreTrainedModel):
|
|||||||
input_ids=None,
|
input_ids=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
global_attention_mask=None,
|
global_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
@@ -1615,6 +1721,11 @@ class LEDEncoder(LEDPreTrainedModel):
|
|||||||
|
|
||||||
- 0 for local attention (a sliding window attention),
|
- 0 for local attention (a sliding window attention),
|
||||||
- 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
|
- 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
||||||
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
|
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
|
||||||
@@ -1686,7 +1797,12 @@ class LEDEncoder(LEDPreTrainedModel):
|
|||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
all_global_attentions = () if (output_attentions and is_global_attn) else None
|
all_global_attentions = () if (output_attentions and is_global_attn) else None
|
||||||
|
|
||||||
for encoder_layer in self.layers:
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
|
if head_mask is not None:
|
||||||
|
assert head_mask.size()[0] == (
|
||||||
|
len(self.layers)
|
||||||
|
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||||
|
for idx, encoder_layer in enumerate(self.layers):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
encoder_states = encoder_states + (hidden_states,)
|
encoder_states = encoder_states + (hidden_states,)
|
||||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||||
@@ -1707,6 +1823,7 @@ class LEDEncoder(LEDPreTrainedModel):
|
|||||||
create_custom_forward(encoder_layer),
|
create_custom_forward(encoder_layer),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
head_mask[idx] if head_mask is not None else None,
|
||||||
is_index_masked,
|
is_index_masked,
|
||||||
is_index_global_attn,
|
is_index_global_attn,
|
||||||
)
|
)
|
||||||
@@ -1714,6 +1831,7 @@ class LEDEncoder(LEDPreTrainedModel):
|
|||||||
layer_outputs = encoder_layer(
|
layer_outputs = encoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||||
is_index_masked=is_index_masked,
|
is_index_masked=is_index_masked,
|
||||||
is_index_global_attn=is_index_global_attn,
|
is_index_global_attn=is_index_global_attn,
|
||||||
is_global_attn=is_global_attn,
|
is_global_attn=is_global_attn,
|
||||||
@@ -1787,6 +1905,8 @@ class LEDDecoder(LEDPreTrainedModel):
|
|||||||
global_attention_mask=None,
|
global_attention_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_head_mask=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
@@ -1833,6 +1953,19 @@ class LEDDecoder(LEDPreTrainedModel):
|
|||||||
- 0 for tokens that are **masked**.
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
|
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||||
|
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||||
decoding.
|
decoding.
|
||||||
@@ -1910,6 +2043,12 @@ class LEDDecoder(LEDPreTrainedModel):
|
|||||||
all_self_attns = () if output_attentions else None
|
all_self_attns = () if output_attentions else None
|
||||||
all_cross_attentions = () if output_attentions else None
|
all_cross_attentions = () if output_attentions else None
|
||||||
next_decoder_cache = () if use_cache else None
|
next_decoder_cache = () if use_cache else None
|
||||||
|
|
||||||
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
|
if head_mask is not None:
|
||||||
|
assert head_mask.size()[0] == (
|
||||||
|
len(self.layers)
|
||||||
|
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||||
for idx, decoder_layer in enumerate(self.layers):
|
for idx, decoder_layer in enumerate(self.layers):
|
||||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
@@ -1942,6 +2081,8 @@ class LEDDecoder(LEDPreTrainedModel):
|
|||||||
combined_attention_mask,
|
combined_attention_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
|
head_mask[idx] if head_mask is not None else None,
|
||||||
|
encoder_head_mask[idx] if encoder_head_mask is not None else None,
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -1950,6 +2091,8 @@ class LEDDecoder(LEDPreTrainedModel):
|
|||||||
attention_mask=combined_attention_mask,
|
attention_mask=combined_attention_mask,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||||
|
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
|
||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
@@ -2027,6 +2170,8 @@ class LEDModel(LEDPreTrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
global_attention_mask=None,
|
global_attention_mask=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
@@ -2049,6 +2194,7 @@ class LEDModel(LEDPreTrainedModel):
|
|||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
global_attention_mask=global_attention_mask,
|
global_attention_mask=global_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -2069,6 +2215,8 @@ class LEDModel(LEDPreTrainedModel):
|
|||||||
attention_mask=decoder_attention_mask,
|
attention_mask=decoder_attention_mask,
|
||||||
encoder_hidden_states=encoder_outputs[0],
|
encoder_hidden_states=encoder_outputs[0],
|
||||||
encoder_attention_mask=attention_mask,
|
encoder_attention_mask=attention_mask,
|
||||||
|
head_mask=decoder_head_mask,
|
||||||
|
encoder_head_mask=head_mask,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=decoder_inputs_embeds,
|
inputs_embeds=decoder_inputs_embeds,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
@@ -2148,6 +2296,8 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
global_attention_mask=None,
|
global_attention_mask=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
@@ -2198,6 +2348,8 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
|
|||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
global_attention_mask=global_attention_mask,
|
global_attention_mask=global_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
decoder_head_mask=decoder_head_mask,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||||
@@ -2231,7 +2383,14 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def prepare_inputs_for_generation(
|
def prepare_inputs_for_generation(
|
||||||
self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
|
self,
|
||||||
|
decoder_input_ids,
|
||||||
|
past=None,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
use_cache=None,
|
||||||
|
encoder_outputs=None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
# cut decoder_input_ids if past is used
|
# cut decoder_input_ids if past is used
|
||||||
if past is not None:
|
if past is not None:
|
||||||
@@ -2243,6 +2402,7 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
|
|||||||
"past_key_values": past,
|
"past_key_values": past,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
|
"head_mask": head_mask,
|
||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2290,6 +2450,8 @@ class LEDForSequenceClassification(LEDPreTrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
global_attention_mask=None,
|
global_attention_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -2320,6 +2482,8 @@ class LEDForSequenceClassification(LEDPreTrainedModel):
|
|||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
global_attention_mask=global_attention_mask,
|
global_attention_mask=global_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
decoder_head_mask=decoder_head_mask,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||||
@@ -2394,6 +2558,8 @@ class LEDForQuestionAnswering(LEDPreTrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
global_attention_mask=None,
|
global_attention_mask=None,
|
||||||
start_positions=None,
|
start_positions=None,
|
||||||
@@ -2425,6 +2591,8 @@ class LEDForQuestionAnswering(LEDPreTrainedModel):
|
|||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
global_attention_mask=global_attention_mask,
|
global_attention_mask=global_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
decoder_head_mask=decoder_head_mask,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||||
|
|||||||
@@ -553,6 +553,7 @@ class LongformerSelfAttention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
|
layer_head_mask=None,
|
||||||
is_index_masked=None,
|
is_index_masked=None,
|
||||||
is_index_global_attn=None,
|
is_index_global_attn=None,
|
||||||
is_global_attn=None,
|
is_global_attn=None,
|
||||||
@@ -640,6 +641,12 @@ class LongformerSelfAttention(nn.Module):
|
|||||||
|
|
||||||
attn_probs = F.softmax(attn_scores, dim=-1, dtype=torch.float32) # use fp32 for numerical stability
|
attn_probs = F.softmax(attn_scores, dim=-1, dtype=torch.float32) # use fp32 for numerical stability
|
||||||
|
|
||||||
|
if layer_head_mask is not None:
|
||||||
|
assert layer_head_mask.size() == (
|
||||||
|
self.num_heads,
|
||||||
|
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
||||||
|
attn_probs = layer_head_mask.view(1, 1, -1, 1) * attn_probs
|
||||||
|
|
||||||
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
|
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
|
||||||
attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
|
attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
|
||||||
attn_probs = attn_probs.type_as(attn_scores)
|
attn_probs = attn_probs.type_as(attn_scores)
|
||||||
@@ -677,6 +684,7 @@ class LongformerSelfAttention(nn.Module):
|
|||||||
global_attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden(
|
global_attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
max_num_global_attn_indices=max_num_global_attn_indices,
|
max_num_global_attn_indices=max_num_global_attn_indices,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
|
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
|
||||||
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
|
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
|
||||||
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
|
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
|
||||||
@@ -984,6 +992,7 @@ class LongformerSelfAttention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
max_num_global_attn_indices,
|
max_num_global_attn_indices,
|
||||||
|
layer_head_mask,
|
||||||
is_local_index_global_attn_nonzero,
|
is_local_index_global_attn_nonzero,
|
||||||
is_index_global_attn_nonzero,
|
is_index_global_attn_nonzero,
|
||||||
is_local_index_no_global_attn_nonzero,
|
is_local_index_no_global_attn_nonzero,
|
||||||
@@ -1045,6 +1054,18 @@ class LongformerSelfAttention(nn.Module):
|
|||||||
global_attn_scores, dim=-1, dtype=torch.float32
|
global_attn_scores, dim=-1, dtype=torch.float32
|
||||||
) # use fp32 for numerical stability
|
) # use fp32 for numerical stability
|
||||||
|
|
||||||
|
# apply layer head masking
|
||||||
|
if layer_head_mask is not None:
|
||||||
|
assert layer_head_mask.size() == (
|
||||||
|
self.num_heads,
|
||||||
|
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
||||||
|
global_attn_probs_float = layer_head_mask.view(1, -1, 1, 1) * global_attn_probs_float.view(
|
||||||
|
batch_size, self.num_heads, max_num_global_attn_indices, seq_len
|
||||||
|
)
|
||||||
|
global_attn_probs_float = global_attn_probs_float.view(
|
||||||
|
batch_size * self.num_heads, max_num_global_attn_indices, seq_len
|
||||||
|
)
|
||||||
|
|
||||||
global_attn_probs = F.dropout(
|
global_attn_probs = F.dropout(
|
||||||
global_attn_probs_float.type_as(global_attn_scores), p=self.dropout, training=self.training
|
global_attn_probs_float.type_as(global_attn_scores), p=self.dropout, training=self.training
|
||||||
)
|
)
|
||||||
@@ -1109,6 +1130,7 @@ class LongformerAttention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
|
layer_head_mask=None,
|
||||||
is_index_masked=None,
|
is_index_masked=None,
|
||||||
is_index_global_attn=None,
|
is_index_global_attn=None,
|
||||||
is_global_attn=None,
|
is_global_attn=None,
|
||||||
@@ -1117,6 +1139,7 @@ class LongformerAttention(nn.Module):
|
|||||||
self_outputs = self.self(
|
self_outputs = self.self(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
is_index_masked=is_index_masked,
|
is_index_masked=is_index_masked,
|
||||||
is_index_global_attn=is_index_global_attn,
|
is_index_global_attn=is_index_global_attn,
|
||||||
is_global_attn=is_global_attn,
|
is_global_attn=is_global_attn,
|
||||||
@@ -1171,6 +1194,7 @@ class LongformerLayer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
|
layer_head_mask=None,
|
||||||
is_index_masked=None,
|
is_index_masked=None,
|
||||||
is_index_global_attn=None,
|
is_index_global_attn=None,
|
||||||
is_global_attn=None,
|
is_global_attn=None,
|
||||||
@@ -1179,6 +1203,7 @@ class LongformerLayer(nn.Module):
|
|||||||
self_attn_outputs = self.attention(
|
self_attn_outputs = self.attention(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
layer_head_mask=layer_head_mask,
|
||||||
is_index_masked=is_index_masked,
|
is_index_masked=is_index_masked,
|
||||||
is_index_global_attn=is_index_global_attn,
|
is_index_global_attn=is_index_global_attn,
|
||||||
is_global_attn=is_global_attn,
|
is_global_attn=is_global_attn,
|
||||||
@@ -1209,6 +1234,7 @@ class LongformerEncoder(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
@@ -1222,7 +1248,12 @@ class LongformerEncoder(nn.Module):
|
|||||||
all_attentions = () if output_attentions else None # All local attentions.
|
all_attentions = () if output_attentions else None # All local attentions.
|
||||||
all_global_attentions = () if (output_attentions and is_global_attn) else None
|
all_global_attentions = () if (output_attentions and is_global_attn) else None
|
||||||
|
|
||||||
for i, layer_module in enumerate(self.layer):
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
|
if head_mask is not None:
|
||||||
|
assert head_mask.size()[0] == (
|
||||||
|
len(self.layer)
|
||||||
|
), f"The head_mask should be specified for {len(self.layer)} layers, but it is for {head_mask.size()[0]}."
|
||||||
|
for idx, layer_module in enumerate(self.layer):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
@@ -1238,6 +1269,7 @@ class LongformerEncoder(nn.Module):
|
|||||||
create_custom_forward(layer_module),
|
create_custom_forward(layer_module),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
head_mask[idx] if head_mask is not None else None,
|
||||||
is_index_masked,
|
is_index_masked,
|
||||||
is_index_global_attn,
|
is_index_global_attn,
|
||||||
)
|
)
|
||||||
@@ -1245,6 +1277,7 @@ class LongformerEncoder(nn.Module):
|
|||||||
layer_outputs = layer_module(
|
layer_outputs = layer_module(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
layer_head_mask=head_mask[idx] if head_mask is not None else None,
|
||||||
is_index_masked=is_index_masked,
|
is_index_masked=is_index_masked,
|
||||||
is_index_global_attn=is_index_global_attn,
|
is_index_global_attn=is_index_global_attn,
|
||||||
is_global_attn=is_global_attn,
|
is_global_attn=is_global_attn,
|
||||||
@@ -1386,6 +1419,18 @@ LONGFORMER_INPUTS_DOCSTRING = r"""
|
|||||||
- 0 for local attention (a sliding window attention),
|
- 0 for local attention (a sliding window attention),
|
||||||
- 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
|
- 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
|
||||||
|
|
||||||
|
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
|
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
|
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
|
||||||
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
|
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
|
||||||
1]``:
|
1]``:
|
||||||
@@ -1534,6 +1579,7 @@ class LongformerModel(LongformerPreTrainedModel):
|
|||||||
input_ids=None,
|
input_ids=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
global_attention_mask=None,
|
global_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1617,6 +1663,7 @@ class LongformerModel(LongformerPreTrainedModel):
|
|||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
embedding_output,
|
embedding_output,
|
||||||
attention_mask=extended_attention_mask,
|
attention_mask=extended_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -1667,6 +1714,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
|
|||||||
input_ids=None,
|
input_ids=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
global_attention_mask=None,
|
global_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1708,6 +1756,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
|
|||||||
input_ids,
|
input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
global_attention_mask=global_attention_mask,
|
global_attention_mask=global_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
@@ -1767,6 +1816,7 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
|
|||||||
input_ids=None,
|
input_ids=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
global_attention_mask=None,
|
global_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1793,6 +1843,7 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
|
|||||||
input_ids,
|
input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
global_attention_mask=global_attention_mask,
|
global_attention_mask=global_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
@@ -1871,6 +1922,7 @@ class LongformerForQuestionAnswering(LongformerPreTrainedModel):
|
|||||||
input_ids=None,
|
input_ids=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
global_attention_mask=None,
|
global_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1932,6 +1984,7 @@ class LongformerForQuestionAnswering(LongformerPreTrainedModel):
|
|||||||
input_ids,
|
input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
global_attention_mask=global_attention_mask,
|
global_attention_mask=global_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
@@ -2011,6 +2064,7 @@ class LongformerForTokenClassification(LongformerPreTrainedModel):
|
|||||||
input_ids=None,
|
input_ids=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
global_attention_mask=None,
|
global_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -2030,6 +2084,7 @@ class LongformerForTokenClassification(LongformerPreTrainedModel):
|
|||||||
input_ids,
|
input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
global_attention_mask=global_attention_mask,
|
global_attention_mask=global_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
@@ -2101,6 +2156,7 @@ class LongformerForMultipleChoice(LongformerPreTrainedModel):
|
|||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
global_attention_mask=None,
|
global_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
labels=None,
|
labels=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -2150,6 +2206,7 @@ class LongformerForMultipleChoice(LongformerPreTrainedModel):
|
|||||||
token_type_ids=flat_token_type_ids,
|
token_type_ids=flat_token_type_ids,
|
||||||
attention_mask=flat_attention_mask,
|
attention_mask=flat_attention_mask,
|
||||||
global_attention_mask=flat_global_attention_mask,
|
global_attention_mask=flat_global_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
inputs_embeds=flat_inputs_embeds,
|
inputs_embeds=flat_inputs_embeds,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
|
|||||||
@@ -473,7 +473,6 @@ class ModelTesterMixin:
|
|||||||
arg_names = [*signature.parameters.keys()]
|
arg_names = [*signature.parameters.keys()]
|
||||||
if "decoder_head_mask" in arg_names: # necessary differentiation because of T5 model
|
if "decoder_head_mask" in arg_names: # necessary differentiation because of T5 model
|
||||||
inputs["decoder_head_mask"] = head_mask
|
inputs["decoder_head_mask"] = head_mask
|
||||||
|
|
||||||
outputs = model(**inputs, return_dict=True)
|
outputs = model(**inputs, return_dict=True)
|
||||||
|
|
||||||
# Test that we can get a gradient back for importance score computation
|
# Test that we can get a gradient back for importance score computation
|
||||||
|
|||||||
@@ -49,16 +49,24 @@ def prepare_led_inputs_dict(
|
|||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
):
|
):
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = input_ids.ne(config.pad_token_id)
|
attention_mask = input_ids.ne(config.pad_token_id)
|
||||||
if decoder_attention_mask is None:
|
if decoder_attention_mask is None:
|
||||||
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
||||||
|
if head_mask is None:
|
||||||
|
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
|
||||||
|
if decoder_head_mask is None:
|
||||||
|
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"decoder_attention_mask": decoder_attention_mask,
|
"decoder_attention_mask": decoder_attention_mask,
|
||||||
|
"head_mask": head_mask,
|
||||||
|
"decoder_head_mask": decoder_head_mask,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -160,9 +168,10 @@ class LEDModelTester:
|
|||||||
model = LEDModel(config=config).get_decoder().to(torch_device).eval()
|
model = LEDModel(config=config).get_decoder().to(torch_device).eval()
|
||||||
input_ids = inputs_dict["input_ids"]
|
input_ids = inputs_dict["input_ids"]
|
||||||
attention_mask = inputs_dict["attention_mask"]
|
attention_mask = inputs_dict["attention_mask"]
|
||||||
|
head_mask = inputs_dict["head_mask"]
|
||||||
|
|
||||||
# first forward pass
|
# first forward pass
|
||||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||||
|
|
||||||
output, past_key_values = outputs.to_tuple()
|
output, past_key_values = outputs.to_tuple()
|
||||||
|
|
||||||
@@ -258,7 +267,6 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
all_generative_model_classes = (LEDForConditionalGeneration,) if is_torch_available() else ()
|
all_generative_model_classes = (LEDForConditionalGeneration,) if is_torch_available() else ()
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_head_masking = False
|
|
||||||
test_missing_keys = False
|
test_missing_keys = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
|||||||
@@ -273,7 +273,6 @@ class LongformerModelTester:
|
|||||||
@require_torch
|
@require_torch
|
||||||
class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
|
class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
test_pruning = False # pruning is not supported
|
test_pruning = False # pruning is not supported
|
||||||
test_headmasking = False # head masking is not supported
|
|
||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
|
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
|
|||||||
Reference in New Issue
Block a user