[All Seq2Seq model + CLM models that can be used with EncoderDecoder] Add cross-attention weights to outputs (#8071)
* Output cross-attention with decoder attention output * Update src/transformers/modeling_bert.py * add cross-attention for t5 and bart as well * fix tests * correct typo in docs * add sylvains and sams comments * correct typo Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -65,12 +65,34 @@ BaseModelOutputWithPooling
|
||||
:members:
|
||||
|
||||
|
||||
BaseModelOutputWithCrossAttentions
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.modeling_outputs.BaseModelOutputWithCrossAttentions
|
||||
:members:
|
||||
|
||||
|
||||
BaseModelOutputWithPoolingAndCrossAttentions
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.modeling_outputs.BaseModelOutputWithPoolingAndCrossAttentions
|
||||
:members:
|
||||
|
||||
|
||||
BaseModelOutputWithPast
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.modeling_outputs.BaseModelOutputWithPast
|
||||
:members:
|
||||
|
||||
|
||||
BaseModelOutputWithPastAndCrossAttentions
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.modeling_outputs.BaseModelOutputWithPastAndCrossAttentions
|
||||
:members:
|
||||
|
||||
|
||||
Seq2SeqModelOutput
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@@ -85,6 +107,20 @@ CausalLMOutput
|
||||
:members:
|
||||
|
||||
|
||||
CausalLMOutputWithCrossAttentions
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.modeling_outputs.CausalLMOutputWithCrossAttentions
|
||||
:members:
|
||||
|
||||
|
||||
CausalLMOutputWithPastAndCrossAttentions
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.modeling_outputs.CausalLMOutputWithPastAndCrossAttentions
|
||||
:members:
|
||||
|
||||
|
||||
CausalLMOutputWithPast
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ from .file_utils import (
|
||||
)
|
||||
from .modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPast,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
Seq2SeqLMOutput,
|
||||
Seq2SeqModelOutput,
|
||||
Seq2SeqQuestionAnsweringModelOutput,
|
||||
@@ -451,11 +451,12 @@ class DecoderLayer(nn.Module):
|
||||
assert self.encoder_attn.cache_key != self.self_attn.cache_key
|
||||
if self.normalize_before:
|
||||
x = self.encoder_attn_layer_norm(x)
|
||||
x, _ = self.encoder_attn(
|
||||
x, cross_attn_weights = self.encoder_attn(
|
||||
query=x,
|
||||
key=encoder_hidden_states,
|
||||
key_padding_mask=encoder_attn_mask,
|
||||
layer_state=layer_state, # mutates layer state
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
@@ -477,7 +478,8 @@ class DecoderLayer(nn.Module):
|
||||
x,
|
||||
self_attn_weights,
|
||||
layer_state,
|
||||
) # just self_attn weights for now, following t5, layer_state = cache for decoding
|
||||
cross_attn_weights,
|
||||
) # layer_state = cache for decoding
|
||||
|
||||
|
||||
class BartDecoder(nn.Module):
|
||||
@@ -590,6 +592,7 @@ class BartDecoder(nn.Module):
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions else None
|
||||
next_decoder_cache: List[Dict] = []
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
@@ -601,7 +604,7 @@ class BartDecoder(nn.Module):
|
||||
|
||||
layer_state = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
x, layer_self_attn, layer_past = decoder_layer(
|
||||
x, layer_self_attn, layer_past, layer_cross_attn = decoder_layer(
|
||||
x,
|
||||
encoder_hidden_states,
|
||||
encoder_attn_mask=encoder_padding_mask,
|
||||
@@ -616,6 +619,7 @@ class BartDecoder(nn.Module):
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_self_attn,)
|
||||
all_cross_attentions += (layer_cross_attn,)
|
||||
|
||||
if self.layer_norm: # if config.add_final_layer_norm (mBART)
|
||||
x = self.layer_norm(x)
|
||||
@@ -628,9 +632,15 @@ class BartDecoder(nn.Module):
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if not return_dict:
|
||||
return tuple(v for v in [x, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=x, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns
|
||||
return tuple(
|
||||
v for v in [x, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=x,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@@ -934,6 +944,7 @@ class BartModel(PretrainedBartModel):
|
||||
past_key_values=decoder_outputs.past_key_values,
|
||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||
decoder_attentions=decoder_outputs.attentions,
|
||||
cross_attentions=decoder_outputs.cross_attentions,
|
||||
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
||||
encoder_hidden_states=encoder_outputs.hidden_states,
|
||||
encoder_attentions=encoder_outputs.attentions,
|
||||
@@ -1078,6 +1089,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
past_key_values=outputs.past_key_values,
|
||||
decoder_hidden_states=outputs.decoder_hidden_states,
|
||||
decoder_attentions=outputs.decoder_attentions,
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
||||
encoder_hidden_states=outputs.encoder_hidden_states,
|
||||
encoder_attentions=outputs.encoder_attentions,
|
||||
@@ -1207,6 +1219,7 @@ class BartForSequenceClassification(PretrainedBartModel):
|
||||
past_key_values=outputs.past_key_values,
|
||||
decoder_hidden_states=outputs.decoder_hidden_states,
|
||||
decoder_attentions=outputs.decoder_attentions,
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
||||
encoder_hidden_states=outputs.encoder_hidden_states,
|
||||
encoder_attentions=outputs.encoder_attentions,
|
||||
@@ -1317,6 +1330,7 @@ class BartForQuestionAnswering(PretrainedBartModel):
|
||||
past_key_values=outputs.past_key_values,
|
||||
decoder_hidden_states=outputs.decoder_hidden_states,
|
||||
decoder_attentions=outputs.decoder_attentions,
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
||||
encoder_hidden_states=outputs.encoder_hidden_states,
|
||||
encoder_attentions=outputs.encoder_attentions,
|
||||
|
||||
@@ -37,9 +37,9 @@ from .file_utils import (
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPooling,
|
||||
CausalLMOutput,
|
||||
BaseModelOutputWithCrossAttentions,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
CausalLMOutputWithCrossAttentions,
|
||||
MaskedLMOutput,
|
||||
MultipleChoiceModelOutput,
|
||||
NextSentencePredictorOutput,
|
||||
@@ -449,7 +449,8 @@ class BertEncoder(nn.Module):
|
||||
return_dict=False,
|
||||
):
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
@@ -483,15 +484,24 @@ class BertEncoder(nn.Module):
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[1],)
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
||||
return tuple(
|
||||
v
|
||||
for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@@ -752,7 +762,7 @@ class BertModel(BertPreTrainedModel):
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="bert-base-uncased",
|
||||
output_type=BaseModelOutputWithPooling,
|
||||
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
@@ -843,11 +853,12 @@ class BertModel(BertPreTrainedModel):
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@@ -984,7 +995,7 @@ class BertLMHeadModel(BertPreTrainedModel):
|
||||
return self.cls.predictions.decoder
|
||||
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
@@ -1063,11 +1074,12 @@ class BertLMHeadModel(BertPreTrainedModel):
|
||||
output = (prediction_scores,) + outputs[2:]
|
||||
return ((lm_loss,) + output) if lm_loss is not None else output
|
||||
|
||||
return CausalLMOutput(
|
||||
return CausalLMOutputWithCrossAttentions(
|
||||
loss=lm_loss,
|
||||
logits=prediction_scores,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
|
||||
|
||||
@@ -28,7 +28,7 @@ from .file_utils import (
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .modeling_bert import BertEncoder
|
||||
from .modeling_outputs import BaseModelOutput, CausalLMOutput
|
||||
from .modeling_outputs import BaseModelOutputWithCrossAttentions, CausalLMOutputWithCrossAttentions
|
||||
from .modeling_utils import PreTrainedModel
|
||||
from .utils import logging
|
||||
|
||||
@@ -297,7 +297,7 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="google/bert_for_seq_generation_L-24_bbc_encoder",
|
||||
output_type=BaseModelOutput,
|
||||
output_type=BaseModelOutputWithCrossAttentions,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
@@ -381,10 +381,11 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
|
||||
if not return_dict:
|
||||
return (sequence_output,) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutput(
|
||||
return BaseModelOutputWithCrossAttentions(
|
||||
last_hidden_state=sequence_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@@ -422,7 +423,7 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel):
|
||||
return self.lm_head.decoder
|
||||
|
||||
@add_start_docstrings_to_model_forward(BERT_GENERATION_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
@@ -499,11 +500,12 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel):
|
||||
output = (prediction_scores,) + outputs[1:]
|
||||
return ((lm_loss,) + output) if lm_loss is not None else output
|
||||
|
||||
return CausalLMOutput(
|
||||
return CausalLMOutputWithCrossAttentions(
|
||||
loss=lm_loss,
|
||||
logits=prediction_scores,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
|
||||
|
||||
@@ -34,7 +34,7 @@ from .file_utils import (
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithCrossAttentions,
|
||||
MaskedLMOutput,
|
||||
MultipleChoiceModelOutput,
|
||||
QuestionAnsweringModelOutput,
|
||||
@@ -445,7 +445,8 @@ class ElectraEncoder(nn.Module):
|
||||
return_dict=False,
|
||||
):
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
@@ -479,15 +480,24 @@ class ElectraEncoder(nn.Module):
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[1],)
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
||||
return tuple(
|
||||
v
|
||||
for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@@ -697,7 +707,7 @@ class ElectraModel(ElectraPreTrainedModel):
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="google/electra-small-discriminator",
|
||||
output_type=BaseModelOutput,
|
||||
output_type=BaseModelOutputWithCrossAttentions,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
|
||||
@@ -426,6 +426,7 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||
past_key_values=None, # TODO(PVP) - need to implement cache for BERT, etc... before this works
|
||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||
decoder_attentions=decoder_outputs.attentions,
|
||||
cross_attentions=decoder_outputs.cross_attentions,
|
||||
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
||||
encoder_hidden_states=encoder_outputs.hidden_states,
|
||||
encoder_attentions=encoder_outputs.attentions,
|
||||
|
||||
@@ -46,7 +46,12 @@ from .file_utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, Seq2SeqLMOutput, Seq2SeqModelOutput
|
||||
from .modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
Seq2SeqLMOutput,
|
||||
Seq2SeqModelOutput,
|
||||
)
|
||||
from .modeling_utils import PreTrainedModel
|
||||
from .utils import logging
|
||||
|
||||
@@ -543,11 +548,12 @@ class DecoderLayer(nn.Module):
|
||||
# Cross attention
|
||||
residual = x
|
||||
assert self.encoder_attn.cache_key != self.self_attn.cache_key
|
||||
x, _ = self.encoder_attn(
|
||||
x, cross_attn_weights = self.encoder_attn(
|
||||
query=x,
|
||||
key=encoder_hidden_states,
|
||||
key_padding_mask=encoder_attn_mask,
|
||||
layer_state=layer_state, # mutates layer state
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
@@ -565,7 +571,8 @@ class DecoderLayer(nn.Module):
|
||||
x,
|
||||
self_attn_weights,
|
||||
layer_state,
|
||||
) # just self_attn weights for now, following t5, layer_state = cache for decoding
|
||||
cross_attn_weights,
|
||||
) # layer_state = cache for decoding
|
||||
|
||||
|
||||
class FSMTDecoder(nn.Module):
|
||||
@@ -669,6 +676,7 @@ class FSMTDecoder(nn.Module):
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
all_cross_attns = () if output_attentions else None
|
||||
next_decoder_cache = []
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
@@ -680,7 +688,7 @@ class FSMTDecoder(nn.Module):
|
||||
|
||||
layer_state = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
x, layer_self_attn, layer_past = decoder_layer(
|
||||
x, layer_self_attn, layer_past, layer_cross_attn = decoder_layer(
|
||||
x,
|
||||
encoder_hidden_states,
|
||||
encoder_attn_mask=encoder_padding_mask,
|
||||
@@ -695,6 +703,7 @@ class FSMTDecoder(nn.Module):
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_self_attn,)
|
||||
all_cross_attns += (layer_cross_attn,)
|
||||
|
||||
# Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
|
||||
if output_hidden_states:
|
||||
@@ -707,9 +716,15 @@ class FSMTDecoder(nn.Module):
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [x, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=x, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns
|
||||
return tuple(
|
||||
v for v in [x, next_cache, all_hidden_states, all_self_attns, all_cross_attns] if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=x,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
cross_attentions=all_cross_attns,
|
||||
)
|
||||
|
||||
|
||||
@@ -903,7 +918,7 @@ class FSMTModel(PretrainedFSMTModel):
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="facebook/wmt19-ru-en",
|
||||
output_type=BaseModelOutputWithPast,
|
||||
output_type=Seq2SeqModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
@@ -989,6 +1004,7 @@ class FSMTModel(PretrainedFSMTModel):
|
||||
past_key_values=decoder_outputs.past_key_values,
|
||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||
decoder_attentions=decoder_outputs.attentions,
|
||||
cross_attentions=decoder_outputs.cross_attentions,
|
||||
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
||||
encoder_hidden_states=encoder_outputs.hidden_states,
|
||||
encoder_attentions=encoder_outputs.attentions,
|
||||
@@ -1101,6 +1117,7 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
|
||||
past_key_values=outputs.past_key_values,
|
||||
decoder_hidden_states=outputs.decoder_hidden_states,
|
||||
decoder_attentions=outputs.decoder_attentions,
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
||||
encoder_hidden_states=outputs.encoder_hidden_states,
|
||||
encoder_attentions=outputs.encoder_attentions,
|
||||
|
||||
@@ -33,7 +33,11 @@ from .file_utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
||||
from .modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
CausalLMOutputWithPastAndCrossAttentions,
|
||||
SequenceClassifierOutputWithPast,
|
||||
)
|
||||
from .modeling_utils import (
|
||||
Conv1D,
|
||||
PreTrainedModel,
|
||||
@@ -311,14 +315,14 @@ class Block(nn.Module):
|
||||
attn_output = cross_attn_outputs[0]
|
||||
# residual connection
|
||||
hidden_states = hidden_states + attn_output
|
||||
outputs = outputs + cross_attn_outputs[1:] # add cross attentions if we output attention weights
|
||||
outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
|
||||
|
||||
feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states))
|
||||
# residual connection
|
||||
hidden_states = hidden_states + feed_forward_hidden_states
|
||||
|
||||
outputs = [hidden_states] + outputs
|
||||
return outputs # hidden_states, present, (cross_attentions, attentions)
|
||||
return outputs # hidden_states, present, (attentions, cross_attentions)
|
||||
|
||||
|
||||
class GPT2PreTrainedModel(PreTrainedModel):
|
||||
@@ -506,7 +510,7 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="gpt2",
|
||||
output_type=BaseModelOutputWithPast,
|
||||
output_type=BaseModelOutputWithPastAndCrossAttentions,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
@@ -618,7 +622,8 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
output_shape = input_shape + (hidden_states.size(-1),)
|
||||
|
||||
presents = () if use_cache else None
|
||||
all_attentions = () if output_attentions else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
if output_hidden_states:
|
||||
@@ -659,7 +664,9 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
presents = presents + (present,)
|
||||
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + (outputs[2],)
|
||||
all_self_attentions = all_self_attentions + (outputs[2],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (outputs[3],)
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
|
||||
@@ -669,13 +676,14 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
|
||||
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=presents,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_attentions,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@@ -727,7 +735,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="gpt2",
|
||||
output_type=CausalLMOutputWithPast,
|
||||
output_type=CausalLMOutputWithPastAndCrossAttentions,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
@@ -795,12 +803,13 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
||||
output = (lm_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
return CausalLMOutputWithPastAndCrossAttentions(
|
||||
loss=loss,
|
||||
logits=lm_logits,
|
||||
past_key_values=transformer_outputs.past_key_values,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
cross_attentions=transformer_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -24,7 +24,12 @@ from torch.nn import CrossEntropyLoss
|
||||
from .activations import ACT2FN
|
||||
from .configuration_layoutlm import LayoutLMConfig
|
||||
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
from .modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, TokenClassifierOutput
|
||||
from .modeling_outputs import (
|
||||
BaseModelOutputWithCrossAttentions,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
MaskedLMOutput,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from .modeling_utils import (
|
||||
PreTrainedModel,
|
||||
apply_chunking_to_forward,
|
||||
@@ -374,7 +379,8 @@ class LayoutLMEncoder(nn.Module):
|
||||
return_dict=False,
|
||||
):
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
@@ -408,15 +414,24 @@ class LayoutLMEncoder(nn.Module):
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[1],)
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
||||
return tuple(
|
||||
v
|
||||
for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@@ -611,7 +626,7 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="layoutlm-base-uncased",
|
||||
output_type=BaseModelOutputWithPooling,
|
||||
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
@@ -716,11 +731,12 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -99,6 +99,120 @@ class BaseModelOutputWithPast(ModelOutput):
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseModelOutputWithCrossAttentions(ModelOutput):
|
||||
"""
|
||||
Base class for model's outputs, with potential hidden states and attentions.
|
||||
|
||||
Args:
|
||||
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` and ``config.add_cross_attention=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
||||
weighted average in the cross-attention heads.
|
||||
"""
|
||||
|
||||
last_hidden_state: torch.FloatTensor
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
|
||||
"""
|
||||
Base class for model's outputs that also contains a pooling of the last hidden states.
|
||||
|
||||
Args:
|
||||
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
pooler_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`):
|
||||
Last layer hidden-state of the first token of the sequence (classification token) further processed by a
|
||||
Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
|
||||
prediction (classification) objective during pretraining.
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` and ``config.add_cross_attention=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
||||
weighted average in the cross-attention heads.
|
||||
"""
|
||||
|
||||
last_hidden_state: torch.FloatTensor
|
||||
pooler_output: torch.FloatTensor = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
|
||||
"""
|
||||
Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
|
||||
|
||||
Args:
|
||||
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
|
||||
If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size,
|
||||
1, hidden_size)` is output.
|
||||
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2,
|
||||
batch_size, num_heads, sequence_length, embed_size_per_head)`).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
|
||||
:obj:`past_key_values` input) to speed up sequential decoding.
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` and ``config.add_cross_attention=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
||||
weighted average in the cross-attention heads.
|
||||
"""
|
||||
|
||||
last_hidden_state: torch.FloatTensor
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Seq2SeqModelOutput(ModelOutput):
|
||||
"""
|
||||
@@ -128,6 +242,12 @@ class Seq2SeqModelOutput(ModelOutput):
|
||||
|
||||
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
|
||||
self-attention heads.
|
||||
cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
||||
weighted average in the cross-attention heads.
|
||||
encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Sequence of hidden-states at the output of the last layer of the encoder of the model.
|
||||
encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
@@ -147,6 +267,7 @@ class Seq2SeqModelOutput(ModelOutput):
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None
|
||||
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
|
||||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
@@ -217,6 +338,85 @@ class CausalLMOutputWithPast(ModelOutput):
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class CausalLMOutputWithCrossAttentions(ModelOutput):
|
||||
"""
|
||||
Base class for causal language model (or autoregressive) outputs.
|
||||
|
||||
Args:
|
||||
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
|
||||
Language modeling loss (for next-token prediction).
|
||||
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
|
||||
Cross attentions weights after the attention softmax, used to compute the weighted average in the
|
||||
cross-attention heads.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor]
|
||||
logits: torch.FloatTensor = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class CausalLMOutputWithPastAndCrossAttentions(ModelOutput):
|
||||
"""
|
||||
Base class for causal language model (or autoregressive) outputs.
|
||||
|
||||
Args:
|
||||
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
|
||||
Language modeling loss (for next-token prediction).
|
||||
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2,
|
||||
batch_size, num_heads, sequence_length, embed_size_per_head)`).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
|
||||
:obj:`past_key_values` input) to speed up sequential decoding.
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
|
||||
Cross attentions weights after the attention softmax, used to compute the weighted average in the
|
||||
cross-attention heads.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits: torch.FloatTensor = None
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SequenceClassifierOutputWithPast(ModelOutput):
|
||||
"""
|
||||
@@ -309,6 +509,12 @@ class Seq2SeqLMOutput(ModelOutput):
|
||||
|
||||
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
|
||||
self-attention heads.
|
||||
cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
||||
weighted average in the cross-attention heads.
|
||||
encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Sequence of hidden-states at the output of the last layer of the encoder of the model.
|
||||
encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
@@ -329,6 +535,7 @@ class Seq2SeqLMOutput(ModelOutput):
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None
|
||||
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
|
||||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
@@ -420,6 +627,12 @@ class Seq2SeqSequenceClassifierOutput(ModelOutput):
|
||||
|
||||
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
|
||||
self-attention heads.
|
||||
cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
||||
weighted average in the cross-attention heads.
|
||||
encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Sequence of hidden-states at the output of the last layer of the encoder of the model.
|
||||
encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
@@ -440,6 +653,7 @@ class Seq2SeqSequenceClassifierOutput(ModelOutput):
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None
|
||||
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
|
||||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
@@ -566,6 +780,12 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput):
|
||||
|
||||
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
|
||||
self-attention heads.
|
||||
cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
||||
weighted average in the cross-attention heads.
|
||||
encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Sequence of hidden-states at the output of the last layer of the encoder of the model.
|
||||
encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
@@ -587,6 +807,7 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput):
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None
|
||||
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
|
||||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
import copy
|
||||
import math
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
@@ -261,7 +262,7 @@ class ProphetNetSeq2SeqLMOutput(ModelOutput):
|
||||
|
||||
Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the
|
||||
weighted average in the self-attention heads.
|
||||
decoder_cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_attn_heads,
|
||||
encoder_sequence_length, decoder_sequence_length)`.
|
||||
|
||||
@@ -288,11 +289,19 @@ class ProphetNetSeq2SeqLMOutput(ModelOutput):
|
||||
decoder_ngram_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
decoder_ngram_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
decoder_cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
|
||||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
@property
|
||||
def decoder_cross_attentions(self):
|
||||
warnings.warn(
|
||||
"`decoder_cross_attentions` is deprecated and will be removed soon. Please use `cross_attentions` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
return self.cross_attentions
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProphetNetSeq2SeqModelOutput(ModelOutput):
|
||||
@@ -337,7 +346,7 @@ class ProphetNetSeq2SeqModelOutput(ModelOutput):
|
||||
|
||||
Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the
|
||||
weighted average in the
|
||||
decoder_cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_attn_heads,
|
||||
encoder_sequence_length, decoder_sequence_length)`.
|
||||
|
||||
@@ -365,11 +374,19 @@ class ProphetNetSeq2SeqModelOutput(ModelOutput):
|
||||
decoder_ngram_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
decoder_ngram_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
decoder_cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
|
||||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
@property
|
||||
def decoder_cross_attentions(self):
|
||||
warnings.warn(
|
||||
"`decoder_cross_attentions` is deprecated and will be removed soon. Please use `cross_attentions` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
return self.cross_attentions
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProphetNetDecoderModelOutput(ModelOutput):
|
||||
@@ -1651,7 +1668,7 @@ class ProphetNetModel(ProphetNetPreTrainedModel):
|
||||
decoder_ngram_hidden_states=decoder_outputs.hidden_states_ngram,
|
||||
decoder_attentions=decoder_outputs.attentions,
|
||||
decoder_ngram_attentions=decoder_outputs.ngram_attentions,
|
||||
decoder_cross_attentions=decoder_outputs.cross_attentions,
|
||||
cross_attentions=decoder_outputs.cross_attentions,
|
||||
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
||||
encoder_hidden_states=encoder_outputs.hidden_states,
|
||||
encoder_attentions=encoder_outputs.attentions,
|
||||
@@ -1766,7 +1783,7 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
|
||||
decoder_ngram_hidden_states=outputs.decoder_ngram_hidden_states,
|
||||
decoder_attentions=outputs.decoder_attentions,
|
||||
decoder_ngram_attentions=outputs.decoder_ngram_attentions,
|
||||
decoder_cross_attentions=outputs.decoder_cross_attentions,
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
||||
encoder_hidden_states=outputs.encoder_hidden_states,
|
||||
encoder_attentions=outputs.encoder_attentions,
|
||||
@@ -1986,6 +2003,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
|
||||
hidden_states_ngram=outputs.hidden_states_ngram,
|
||||
attentions=outputs.attentions,
|
||||
ngram_attentions=outputs.ngram_attentions,
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
)
|
||||
|
||||
def _compute_loss(self, logits, labels):
|
||||
|
||||
@@ -31,9 +31,9 @@ from .file_utils import (
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPooling,
|
||||
CausalLMOutput,
|
||||
BaseModelOutputWithCrossAttentions,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
CausalLMOutputWithCrossAttentions,
|
||||
MaskedLMOutput,
|
||||
MultipleChoiceModelOutput,
|
||||
QuestionAnsweringModelOutput,
|
||||
@@ -393,7 +393,8 @@ class RobertaEncoder(nn.Module):
|
||||
return_dict=False,
|
||||
):
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
@@ -427,15 +428,24 @@ class RobertaEncoder(nn.Module):
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[1],)
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
||||
return tuple(
|
||||
v
|
||||
for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@@ -599,7 +609,7 @@ class RobertaModel(RobertaPreTrainedModel):
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="roberta-base",
|
||||
output_type=BaseModelOutputWithPooling,
|
||||
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
# Copied from transformers.modeling_bert.BertModel.forward
|
||||
@@ -689,11 +699,12 @@ class RobertaModel(RobertaPreTrainedModel):
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@@ -719,7 +730,7 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
|
||||
return self.lm_head.decoder
|
||||
|
||||
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
@@ -799,11 +810,12 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
|
||||
output = (prediction_scores,) + outputs[2:]
|
||||
return ((lm_loss,) + output) if lm_loss is not None else output
|
||||
|
||||
return CausalLMOutput(
|
||||
return CausalLMOutputWithCrossAttentions(
|
||||
loss=lm_loss,
|
||||
logits=prediction_scores,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
cross_attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
|
||||
|
||||
@@ -33,7 +33,12 @@ from .file_utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, Seq2SeqLMOutput, Seq2SeqModelOutput
|
||||
from .modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
Seq2SeqLMOutput,
|
||||
Seq2SeqModelOutput,
|
||||
)
|
||||
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from .utils import logging
|
||||
|
||||
@@ -503,6 +508,7 @@ class T5Block(nn.Module):
|
||||
past_key_value=None,
|
||||
use_cache=False,
|
||||
output_attentions=False,
|
||||
return_dict=False,
|
||||
):
|
||||
|
||||
if past_key_value is not None:
|
||||
@@ -533,7 +539,8 @@ class T5Block(nn.Module):
|
||||
hidden_states, present_key_value_state = self_attention_outputs[:2]
|
||||
attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
|
||||
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
do_cross_attention = self.is_decoder and encoder_hidden_states is not None
|
||||
if do_cross_attention:
|
||||
# the actual query length is unknown for cross attention
|
||||
# if using past key value states. Need to inject it here
|
||||
if present_key_value_state is not None:
|
||||
@@ -564,7 +571,6 @@ class T5Block(nn.Module):
|
||||
hidden_states = self.layer[-1](hidden_states)
|
||||
outputs = (hidden_states,)
|
||||
|
||||
# Add attentions if we output them
|
||||
outputs = outputs + (present_key_value_state,) + attention_outputs
|
||||
return outputs # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
|
||||
|
||||
@@ -743,6 +749,7 @@ class T5Stack(T5PreTrainedModel):
|
||||
present_key_value_states = () if use_cache else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if (output_attentions and self.is_decoder) else None
|
||||
position_bias = None
|
||||
encoder_decoder_position_bias = None
|
||||
|
||||
@@ -779,7 +786,9 @@ class T5Stack(T5PreTrainedModel):
|
||||
present_key_value_states = present_key_value_states + (present_key_value_state,)
|
||||
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[2],) # We keep only self-attention weights for now
|
||||
all_attentions = all_attentions + (layer_outputs[2],)
|
||||
if self.is_decoder:
|
||||
all_cross_attentions = all_cross_attentions + (layer_outputs[4 if i == 0 else 3],)
|
||||
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
@@ -791,14 +800,21 @@ class T5Stack(T5PreTrainedModel):
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [hidden_states, present_key_value_states, all_hidden_states, all_attentions]
|
||||
for v in [
|
||||
hidden_states,
|
||||
present_key_value_states,
|
||||
all_hidden_states,
|
||||
all_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPast(
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=present_key_value_states,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@@ -1038,6 +1054,7 @@ class T5Model(T5PreTrainedModel):
|
||||
past_key_values=decoder_outputs.past_key_values,
|
||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||
decoder_attentions=decoder_outputs.attentions,
|
||||
cross_attentions=decoder_outputs.cross_attentions,
|
||||
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
||||
encoder_hidden_states=encoder_outputs.hidden_states,
|
||||
encoder_attentions=encoder_outputs.attentions,
|
||||
@@ -1227,6 +1244,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
past_key_values=decoder_outputs.past_key_values,
|
||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||
decoder_attentions=decoder_outputs.attentions,
|
||||
cross_attentions=decoder_outputs.cross_attentions,
|
||||
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
||||
encoder_hidden_states=encoder_outputs.hidden_states,
|
||||
encoder_attentions=encoder_outputs.attentions,
|
||||
|
||||
@@ -253,9 +253,7 @@ class ModelTesterMixin:
|
||||
out_len = len(outputs)
|
||||
|
||||
if self.is_encoder_decoder:
|
||||
correct_outlen = (
|
||||
self.model_tester.base_model_out_len if hasattr(self.model_tester, "base_model_out_len") else 4
|
||||
)
|
||||
correct_outlen = 5
|
||||
|
||||
# loss is at first position
|
||||
if "labels" in inputs_dict:
|
||||
@@ -266,6 +264,7 @@ class ModelTesterMixin:
|
||||
|
||||
self.assertEqual(out_len, correct_outlen)
|
||||
|
||||
# decoder attentions
|
||||
decoder_attentions = outputs.decoder_attentions
|
||||
self.assertIsInstance(decoder_attentions, (list, tuple))
|
||||
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
|
||||
@@ -274,6 +273,19 @@ class ModelTesterMixin:
|
||||
[self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
|
||||
)
|
||||
|
||||
# cross attentions
|
||||
cross_attentions = outputs.cross_attentions
|
||||
self.assertIsInstance(cross_attentions, (list, tuple))
|
||||
self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
|
||||
self.assertListEqual(
|
||||
list(cross_attentions[0].shape[-3:]),
|
||||
[
|
||||
self.model_tester.num_attention_heads,
|
||||
decoder_seq_length,
|
||||
encoder_key_length,
|
||||
],
|
||||
)
|
||||
|
||||
# Check attention is always last and order is fine
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
|
||||
@@ -292,6 +292,62 @@ class EncoderDecoderMixin:
|
||||
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
|
||||
)
|
||||
|
||||
def check_encoder_decoder_model_output_attentions(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
labels,
|
||||
**kwargs
|
||||
):
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
enc_dec_model.to(torch_device)
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
input_ids=input_ids,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
output_attentions=True,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
|
||||
self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
|
||||
|
||||
self.assertListEqual(
|
||||
list(encoder_attentions[0].shape[-3:]),
|
||||
[config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1]],
|
||||
)
|
||||
|
||||
decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
|
||||
num_decoder_layers = (
|
||||
decoder_config.num_decoder_layers
|
||||
if hasattr(decoder_config, "num_decoder_layers")
|
||||
else decoder_config.num_hidden_layers
|
||||
)
|
||||
self.assertEqual(len(decoder_attentions), num_decoder_layers)
|
||||
|
||||
self.assertListEqual(
|
||||
list(decoder_attentions[0].shape[-3:]),
|
||||
[decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]],
|
||||
)
|
||||
|
||||
cross_attentions = outputs_encoder_decoder["cross_attentions"]
|
||||
self.assertEqual(len(cross_attentions), num_decoder_layers)
|
||||
|
||||
cross_attention_input_seq_len = input_ids.shape[-1] * (
|
||||
1 + (decoder_config.ngram if hasattr(decoder_config, "ngram") else 0)
|
||||
)
|
||||
self.assertListEqual(
|
||||
list(cross_attentions[0].shape[-3:]),
|
||||
[decoder_config.num_attention_heads, cross_attention_input_seq_len, decoder_input_ids.shape[-1]],
|
||||
)
|
||||
|
||||
def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
@@ -413,6 +469,10 @@ class EncoderDecoderMixin:
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model_labels(**input_ids_dict)
|
||||
|
||||
def test_encoder_decoder_model_output_attentions(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model_output_attentions(**input_ids_dict)
|
||||
|
||||
def test_encoder_decoder_model_generate(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model_generate(**input_ids_dict)
|
||||
|
||||
@@ -916,6 +916,116 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs)
|
||||
|
||||
# methods overwrite method in `test_modeling_common.py`
|
||||
def test_attention_outputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.return_dict = True
|
||||
|
||||
seq_len = getattr(self.model_tester, "seq_length", None)
|
||||
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
|
||||
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
|
||||
decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
|
||||
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
|
||||
chunk_length = getattr(self.model_tester, "chunk_length", None)
|
||||
if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
|
||||
encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = False
|
||||
config.return_dict = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
# check that output_attentions also work using config
|
||||
del inputs_dict["output_attentions"]
|
||||
config.output_attentions = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
if chunk_length is not None:
|
||||
self.assertListEqual(
|
||||
list(attentions[0].shape[-4:]),
|
||||
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
|
||||
)
|
||||
else:
|
||||
self.assertListEqual(
|
||||
list(attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||
)
|
||||
out_len = len(outputs)
|
||||
|
||||
correct_outlen = 7
|
||||
|
||||
# loss is at first position
|
||||
if "labels" in inputs_dict:
|
||||
correct_outlen += 1 # loss is added to beginning
|
||||
|
||||
self.assertEqual(out_len, correct_outlen)
|
||||
|
||||
# decoder attentions
|
||||
decoder_attentions = outputs.decoder_attentions
|
||||
self.assertIsInstance(decoder_attentions, (list, tuple))
|
||||
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
|
||||
self.assertListEqual(
|
||||
list(decoder_attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
|
||||
)
|
||||
|
||||
# cross attentions
|
||||
cross_attentions = outputs.cross_attentions
|
||||
self.assertIsInstance(cross_attentions, (list, tuple))
|
||||
self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
|
||||
self.assertListEqual(
|
||||
list(cross_attentions[0].shape[-3:]),
|
||||
[
|
||||
self.model_tester.num_attention_heads,
|
||||
(self.model_tester.ngram + 1) * decoder_seq_length,
|
||||
encoder_key_length,
|
||||
],
|
||||
)
|
||||
|
||||
# Check attention is always last and order is fine
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
if hasattr(self.model_tester, "num_hidden_states_types"):
|
||||
added_hidden_states = self.model_tester.num_hidden_states_types
|
||||
elif self.is_encoder_decoder:
|
||||
added_hidden_states = 2
|
||||
else:
|
||||
added_hidden_states = 1
|
||||
self.assertEqual(out_len + added_hidden_states, len(outputs))
|
||||
|
||||
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||
|
||||
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
|
||||
if chunk_length is not None:
|
||||
self.assertListEqual(
|
||||
list(self_attentions[0].shape[-4:]),
|
||||
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
|
||||
)
|
||||
else:
|
||||
self.assertListEqual(
|
||||
list(self_attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user