From 00031785a846c5a9fe7541a717a436840b609084 Mon Sep 17 00:00:00 2001 From: demSd Date: Thu, 4 Feb 2021 09:56:12 +0100 Subject: [PATCH] BartForCausalLM analogs to `ProphetNetForCausalLM` (#9128) * initiliaze bart4causalLM * create BartDecoderWrapper, setters/getters * delete spaces * forward and additional methods * update cache function, loss function, remove ngram* params in data class. * add bartcausallm, bartdecoder testing * correct bart for causal lm * remove at * add mbart as well * up * fix typo * up * correct * add pegasusforcausallm * add blenderbotforcausallm * add blenderbotsmallforcausallm * add marianforcausallm * add test for MarianForCausalLM * add Pegasus test * add BlenderbotSmall test * add blenderbot test * fix a fail * fix an import fail * a fix * fix * Update modeling_pegasus.py * fix models * fix inputs_embeds setting getter * adapt tests * correct repo utils check * finish test improvement * fix tf models as well * make style * make fix-copies * fix copies * run all tests * last changes * fix all tests Co-authored-by: Patrick von Platen --- docs/source/model_doc/bart.rst | 6 + docs/source/model_doc/blenderbot.rst | 7 + docs/source/model_doc/blenderbot_small.rst | 7 + docs/source/model_doc/marian.rst | 7 + docs/source/model_doc/mbart.rst | 7 + docs/source/model_doc/pegasus.rst | 7 + src/transformers/__init__.py | 18 +- src/transformers/models/auto/modeling_auto.py | 20 +- src/transformers/models/bart/__init__.py | 2 + src/transformers/models/bart/modeling_bart.py | 259 +++++++++++- .../models/bart/modeling_tf_bart.py | 2 +- .../models/blenderbot/__init__.py | 2 + .../models/blenderbot/modeling_blenderbot.py | 277 ++++++++++-- .../blenderbot/modeling_tf_blenderbot.py | 2 +- .../models/blenderbot_small/__init__.py | 2 + .../modeling_blenderbot_small.py | 260 +++++++++++- .../modeling_tf_blenderbot_small.py | 2 +- src/transformers/models/marian/__init__.py | 2 + .../models/marian/modeling_marian.py | 260 +++++++++++- .../models/marian/modeling_tf_marian.py | 2 +- src/transformers/models/mbart/__init__.py | 2 + .../models/mbart/modeling_mbart.py | 262 +++++++++++- .../models/mbart/modeling_tf_mbart.py | 2 +- src/transformers/models/pegasus/__init__.py | 2 + .../models/pegasus/modeling_pegasus.py | 261 +++++++++++- .../models/pegasus/modeling_tf_pegasus.py | 2 +- .../models/prophetnet/modeling_prophetnet.py | 2 +- src/transformers/utils/dummy_pt_objects.py | 30 ++ ...ng_{{cookiecutter.lowercase_modelname}}.py | 398 +++++++++++++++++- ...ng_{{cookiecutter.lowercase_modelname}}.py | 218 +++++++++- tests/test_modeling_bart.py | 244 ++++++++++- tests/test_modeling_blenderbot.py | 228 +++++++++- tests/test_modeling_blenderbot_small.py | 222 +++++++++- tests/test_modeling_encoder_decoder.py | 56 +++ tests/test_modeling_marian.py | 228 +++++++++- tests/test_modeling_mbart.py | 221 +++++++++- tests/test_modeling_pegasus.py | 223 +++++++++- utils/check_repo.py | 20 +- 38 files changed, 3595 insertions(+), 177 deletions(-) diff --git a/docs/source/model_doc/bart.rst b/docs/source/model_doc/bart.rst index 4bd4b94253..3e754f67c7 100644 --- a/docs/source/model_doc/bart.rst +++ b/docs/source/model_doc/bart.rst @@ -130,6 +130,12 @@ BartForQuestionAnswering .. autoclass:: transformers.BartForQuestionAnswering :members: forward +BartForCausalLM +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.BartForCausalLM + :members: forward + TFBartModel diff --git a/docs/source/model_doc/blenderbot.rst b/docs/source/model_doc/blenderbot.rst index 43b4fb7a93..4a13199d61 100644 --- a/docs/source/model_doc/blenderbot.rst +++ b/docs/source/model_doc/blenderbot.rst @@ -98,6 +98,13 @@ See :obj:`transformers.BartForConditionalGeneration` for arguments to `forward` :members: forward +BlenderbotForCausalLM +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.BlenderbotForCausalLM + :members: forward + + TFBlenderbotModel ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/model_doc/blenderbot_small.rst b/docs/source/model_doc/blenderbot_small.rst index ede238d47b..9eb2a5c0ee 100644 --- a/docs/source/model_doc/blenderbot_small.rst +++ b/docs/source/model_doc/blenderbot_small.rst @@ -70,6 +70,13 @@ BlenderbotSmallForConditionalGeneration :members: forward +BlenderbotSmallForCausalLM +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.BlenderbotSmallForCausalLM + :members: forward + + TFBlenderbotSmallModel ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/model_doc/marian.rst b/docs/source/model_doc/marian.rst index fba8e7d876..18d515a869 100644 --- a/docs/source/model_doc/marian.rst +++ b/docs/source/model_doc/marian.rst @@ -193,6 +193,13 @@ MarianMTModel :members: forward +MarianForCausalLM +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.MarianForCausalLM + :members: forward + + TFMarianModel ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/model_doc/mbart.rst b/docs/source/model_doc/mbart.rst index 0a43d0f12c..0a9c1cff90 100644 --- a/docs/source/model_doc/mbart.rst +++ b/docs/source/model_doc/mbart.rst @@ -124,6 +124,13 @@ MBartForSequenceClassification .. autoclass:: transformers.MBartForSequenceClassification +MBartForCausalLM +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.MBartForCausalLM + :members: forward + + TFMBartModel ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/model_doc/pegasus.rst b/docs/source/model_doc/pegasus.rst index b27f7d074b..ad582230e9 100644 --- a/docs/source/model_doc/pegasus.rst +++ b/docs/source/model_doc/pegasus.rst @@ -131,6 +131,13 @@ PegasusForConditionalGeneration :members: forward +PegasusForCausalLM +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.PegasusForCausalLM + :members: forward + + TFPegasusModel ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 259363d561..a8801400a5 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -431,6 +431,7 @@ if is_torch_available(): _import_structure["models.bart"].extend( [ "BART_PRETRAINED_MODEL_ARCHIVE_LIST", + "BartForCausalLM", "BartForConditionalGeneration", "BartForQuestionAnswering", "BartForSequenceClassification", @@ -468,6 +469,7 @@ if is_torch_available(): "BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST", "BlenderbotForConditionalGeneration", "BlenderbotModel", + "BlenderbotForCausalLM", ] ) _import_structure["models.blenderbot_small"].extend( @@ -475,6 +477,7 @@ if is_torch_available(): "BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST", "BlenderbotSmallForConditionalGeneration", "BlenderbotSmallModel", + "BlenderbotSmallForCausalLM", ] ) _import_structure["models.camembert"].extend( @@ -628,9 +631,10 @@ if is_torch_available(): "LxmertXLayer", ] ) - _import_structure["models.marian"].extend(["MarianModel", "MarianMTModel"]) + _import_structure["models.marian"].extend(["MarianModel", "MarianMTModel", "MarianForCausalLM"]) _import_structure["models.mbart"].extend( [ + "MBartForCausalLM", "MBartForConditionalGeneration", "MBartForQuestionAnswering", "MBartForSequenceClassification", @@ -679,7 +683,9 @@ if is_torch_available(): "load_tf_weights_in_openai_gpt", ] ) - _import_structure["models.pegasus"].extend(["PegasusForConditionalGeneration", "PegasusModel"]) + _import_structure["models.pegasus"].extend( + ["PegasusForConditionalGeneration", "PegasusModel", "PegasusForCausalLM"] + ) _import_structure["models.prophetnet"].extend( [ "PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -1517,6 +1523,7 @@ if TYPE_CHECKING: ) from .models.bart import ( BART_PRETRAINED_MODEL_ARCHIVE_LIST, + BartForCausalLM, BartForConditionalGeneration, BartForQuestionAnswering, BartForSequenceClassification, @@ -1546,11 +1553,13 @@ if TYPE_CHECKING: ) from .models.blenderbot import ( BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST, + BlenderbotForCausalLM, BlenderbotForConditionalGeneration, BlenderbotModel, ) from .models.blenderbot_small import ( BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST, + BlenderbotSmallForCausalLM, BlenderbotSmallForConditionalGeneration, BlenderbotSmallModel, ) @@ -1691,8 +1700,9 @@ if TYPE_CHECKING: LxmertVisualFeatureEncoder, LxmertXLayer, ) - from .models.marian import MarianModel, MarianMTModel + from .models.marian import MarianForCausalLM, MarianModel, MarianMTModel from .models.mbart import ( + MBartForCausalLM, MBartForConditionalGeneration, MBartForQuestionAnswering, MBartForSequenceClassification, @@ -1734,7 +1744,7 @@ if TYPE_CHECKING: OpenAIGPTPreTrainedModel, load_tf_weights_in_openai_gpt, ) - from .models.pegasus import PegasusForConditionalGeneration, PegasusModel + from .models.pegasus import PegasusForCausalLM, PegasusForConditionalGeneration, PegasusModel from .models.prophetnet import ( PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST, ProphetNetDecoder, diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 7ddf27863b..9954589538 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -33,6 +33,7 @@ from ..albert.modeling_albert import ( AlbertModel, ) from ..bart.modeling_bart import ( + BartForCausalLM, BartForConditionalGeneration, BartForQuestionAnswering, BartForSequenceClassification, @@ -50,8 +51,12 @@ from ..bert.modeling_bert import ( BertModel, ) from ..bert_generation.modeling_bert_generation import BertGenerationDecoder, BertGenerationEncoder -from ..blenderbot.modeling_blenderbot import BlenderbotForConditionalGeneration, BlenderbotModel -from ..blenderbot_small.modeling_blenderbot_small import BlenderbotSmallForConditionalGeneration, BlenderbotSmallModel +from ..blenderbot.modeling_blenderbot import BlenderbotForCausalLM, BlenderbotForConditionalGeneration, BlenderbotModel +from ..blenderbot_small.modeling_blenderbot_small import ( + BlenderbotSmallForCausalLM, + BlenderbotSmallForConditionalGeneration, + BlenderbotSmallModel, +) from ..camembert.modeling_camembert import ( CamembertForCausalLM, CamembertForMaskedLM, @@ -138,8 +143,9 @@ from ..longformer.modeling_longformer import ( LongformerModel, ) from ..lxmert.modeling_lxmert import LxmertForPreTraining, LxmertForQuestionAnswering, LxmertModel -from ..marian.modeling_marian import MarianModel, MarianMTModel +from ..marian.modeling_marian import MarianForCausalLM, MarianModel, MarianMTModel from ..mbart.modeling_mbart import ( + MBartForCausalLM, MBartForConditionalGeneration, MBartForQuestionAnswering, MBartForSequenceClassification, @@ -165,7 +171,7 @@ from ..mpnet.modeling_mpnet import ( ) from ..mt5.modeling_mt5 import MT5ForConditionalGeneration, MT5Model from ..openai.modeling_openai import OpenAIGPTForSequenceClassification, OpenAIGPTLMHeadModel, OpenAIGPTModel -from ..pegasus.modeling_pegasus import PegasusForConditionalGeneration, PegasusModel +from ..pegasus.modeling_pegasus import PegasusForCausalLM, PegasusForConditionalGeneration, PegasusModel from ..prophetnet.modeling_prophetnet import ProphetNetForCausalLM, ProphetNetForConditionalGeneration, ProphetNetModel from ..rag.modeling_rag import ( # noqa: F401 - need to import all RagModels to be in globals() function RagModel, @@ -425,6 +431,12 @@ MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict( (BertGenerationConfig, BertGenerationDecoder), (XLMProphetNetConfig, XLMProphetNetForCausalLM), (ProphetNetConfig, ProphetNetForCausalLM), + (BartConfig, BartForCausalLM), + (MBartConfig, MBartForCausalLM), + (PegasusConfig, PegasusForCausalLM), + (MarianConfig, MarianForCausalLM), + (BlenderbotConfig, BlenderbotForCausalLM), + (BlenderbotSmallConfig, BlenderbotSmallForCausalLM), ] ) diff --git a/src/transformers/models/bart/__init__.py b/src/transformers/models/bart/__init__.py index 979ad98faf..1742b58bb9 100644 --- a/src/transformers/models/bart/__init__.py +++ b/src/transformers/models/bart/__init__.py @@ -31,6 +31,7 @@ if is_tokenizers_available(): if is_torch_available(): _import_structure["modeling_bart"] = [ "BART_PRETRAINED_MODEL_ARCHIVE_LIST", + "BartForCausalLM", "BartForConditionalGeneration", "BartForQuestionAnswering", "BartForSequenceClassification", @@ -53,6 +54,7 @@ if TYPE_CHECKING: if is_torch_available(): from .modeling_bart import ( BART_PRETRAINED_MODEL_ARCHIVE_LIST, + BartForCausalLM, BartForConditionalGeneration, BartForQuestionAnswering, BartForSequenceClassification, diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index c8462276bf..3c944ecf19 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ PyTorch BART model. """ - - +import copy import math import random import warnings @@ -37,6 +36,7 @@ from ...file_utils import ( from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, Seq2SeqQuestionAnsweringModelOutput, @@ -843,6 +843,30 @@ class BartDecoder(BartPretrainedModel): self.init_weights() + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length + ).to(self.device) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + def forward( self, input_ids=None, @@ -945,19 +969,9 @@ class BartDecoder(BartPretrainedModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length - ).to(self.device) - - if attention_mask is not None and combined_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = combined_attention_mask + _expand_mask( - attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: @@ -975,7 +989,7 @@ class BartDecoder(BartPretrainedModel): # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - all_cross_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None next_decoder_cache = () if use_cache else None # check if head_mask has a correct number of layers specified if desired @@ -1012,7 +1026,7 @@ class BartDecoder(BartPretrainedModel): layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(decoder_layer), hidden_states, - combined_attention_mask, + attention_mask, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, @@ -1023,7 +1037,7 @@ class BartDecoder(BartPretrainedModel): layer_outputs = decoder_layer( hidden_states, - attention_mask=combined_attention_mask, + attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), @@ -1039,7 +1053,9 @@ class BartDecoder(BartPretrainedModel): if output_attentions: all_self_attns += (layer_outputs[1],) - all_cross_attentions += (layer_outputs[2],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) # add hidden states from the last decoder layer if output_hidden_states: @@ -1571,3 +1587,208 @@ class BartForQuestionAnswering(BartPretrainedModel): encoder_hidden_states=outputs.encoder_hidden_states, encoder_attentions=outputs.encoder_attentions, ) + + +class BartDecoderWrapper(BartPretrainedModel): + """ + This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is + used in combination with the :class:`~transformers.EncoderDecoderModel` framework. + """ + + def __init__(self, config): + super().__init__(config) + self.decoder = BartDecoder(config) + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +class BartForCausalLM(BartPretrainedModel): + def __init__(self, config): + super().__init__(config) + config = copy.deepcopy(config) + config.is_decoder = True + config.is_encoder_decoder = False + self.model = BartDecoderWrapper(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.init_weights() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + head_mask=None, + encoder_head_mask=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` + for details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + if the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used + in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up + decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last ``decoder_input_ids`` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should either be in ``[0, ..., + config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., + config.vocab_size]``. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors + for more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. + + Returns: + + Example:: + + >>> from transformers import BartTokenizer, BartForCausalLM + + >>> tokenizer = BartTokenizer.from_pretrained('facebook/bart-large') + >>> model = BartForCausalLM.from_pretrained('facebook/bart-large', add_cross_attention=False) + >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_head_mask=encoder_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.lm_head(outputs[0]) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + + if past: + input_ids = input_ids[:, -1:] + # first step, decoder_cached_states are empty + return { + "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed + "attention_mask": attention_mask, + "past_key_values": past, + "use_cache": use_cache, + } + + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py index a5a6b907c6..4d3cd25b9a 100644 --- a/src/transformers/models/bart/modeling_tf_bart.py +++ b/src/transformers/models/bart/modeling_tf_bart.py @@ -915,7 +915,7 @@ class TFBartDecoder(tf.keras.layers.Layer): tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1] ) - if inputs["attention_mask"] is not None and input_shape[-1] > 1: + if inputs["attention_mask"] is not None: combined_attention_mask = combined_attention_mask + _expand_mask( inputs["attention_mask"], tgt_len=input_shape[-1] ) diff --git a/src/transformers/models/blenderbot/__init__.py b/src/transformers/models/blenderbot/__init__.py index b418d3e9b3..cd46ae5703 100644 --- a/src/transformers/models/blenderbot/__init__.py +++ b/src/transformers/models/blenderbot/__init__.py @@ -32,6 +32,7 @@ if is_torch_available(): "BlenderbotForConditionalGeneration", "BlenderbotModel", "BlenderbotPreTrainedModel", + "BlenderbotForCausalLM", ] @@ -46,6 +47,7 @@ if TYPE_CHECKING: if is_torch_available(): from .modeling_blenderbot import ( BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST, + BlenderbotForCausalLM, BlenderbotForConditionalGeneration, BlenderbotModel, BlenderbotPreTrainedModel, diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 52e62863eb..2c2007ffe6 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -15,6 +15,7 @@ """ PyTorch Blenderbot model. """ +import copy import math import os import random @@ -37,6 +38,7 @@ from ...file_utils import ( from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, ) @@ -805,13 +807,38 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): self.init_weights() + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length + ).to(self.device) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + def forward( self, input_ids=None, attention_mask=None, - head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + head_mask=None, encoder_head_mask=None, past_key_values=None, inputs_embeds=None, @@ -838,12 +865,6 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: - - - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. - encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, encoder_sequence_length, hidden_size)`, `optional`): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. @@ -855,6 +876,12 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention on hidden heads. Mask values selected in ``[0, 1]``: @@ -907,19 +934,9 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length - ).to(self.device) - - if attention_mask is not None and combined_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = combined_attention_mask + _expand_mask( - attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: @@ -929,7 +946,6 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): # embed positions positions = self.embed_positions(input_shape, past_key_values_length) - # in constrast to Bart, Blenderbot applies layernorm on inputs_embeds hidden_states = inputs_embeds + positions hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) @@ -937,7 +953,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - all_cross_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None next_decoder_cache = () if use_cache else None # check if head_mask has a correct number of layers specified if desired @@ -974,7 +990,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(decoder_layer), hidden_states, - combined_attention_mask, + attention_mask, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, @@ -985,10 +1001,10 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): layer_outputs = decoder_layer( hidden_states, - attention_mask=combined_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), + attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), past_key_value=past_key_value, output_attentions=output_attentions, @@ -1001,7 +1017,9 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): if output_attentions: all_self_attns += (layer_outputs[1],) - all_cross_attentions += (layer_outputs[2],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) # add final layer norm hidden_states = self.layer_norm(hidden_states) @@ -1336,3 +1354,210 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel): tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], ) return reordered_past + + +# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Blenderbot +class BlenderbotDecoderWrapper(BlenderbotPreTrainedModel): + """ + This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is + used in combination with the :class:`~transformers.EncoderDecoderModel` framework. + """ + + def __init__(self, config): + super().__init__(config) + self.decoder = BlenderbotDecoder(config) + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Blenderbot +class BlenderbotForCausalLM(BlenderbotPreTrainedModel): + def __init__(self, config): + super().__init__(config) + config = copy.deepcopy(config) + config.is_decoder = True + config.is_encoder_decoder = False + self.model = BlenderbotDecoderWrapper(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.init_weights() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + head_mask=None, + encoder_head_mask=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` + for details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + if the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used + in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up + decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last ``decoder_input_ids`` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should either be in ``[0, ..., + config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., + config.vocab_size]``. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors + for more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. + + Returns: + + Example:: + + >>> from transformers import BlenderbotTokenizer, BlenderbotForCausalLM + + >>> tokenizer = BlenderbotTokenizer.from_pretrained('facebook/bart-large') + >>> model = BlenderbotForCausalLM.from_pretrained('facebook/bart-large', add_cross_attention=False) + >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_head_mask=encoder_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.lm_head(outputs[0]) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + + if past: + input_ids = input_ids[:, -1:] + # first step, decoder_cached_states are empty + return { + "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed + "attention_mask": attention_mask, + "past_key_values": past, + "use_cache": use_cache, + } + + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past diff --git a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py index 258591c7f7..1ef33bbc57 100644 --- a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py @@ -925,7 +925,7 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer): tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1] ) - if inputs["attention_mask"] is not None and input_shape[-1] > 1: + if inputs["attention_mask"] is not None: combined_attention_mask = combined_attention_mask + _expand_mask( inputs["attention_mask"], tgt_len=input_shape[-1] ) diff --git a/src/transformers/models/blenderbot_small/__init__.py b/src/transformers/models/blenderbot_small/__init__.py index 2b936cd90a..2f60bc77c0 100644 --- a/src/transformers/models/blenderbot_small/__init__.py +++ b/src/transformers/models/blenderbot_small/__init__.py @@ -31,6 +31,7 @@ if is_torch_available(): "BlenderbotSmallForConditionalGeneration", "BlenderbotSmallModel", "BlenderbotSmallPreTrainedModel", + "BlenderbotSmallForCausalLM", ] if is_tf_available(): @@ -46,6 +47,7 @@ if TYPE_CHECKING: if is_torch_available(): from .modeling_blenderbot_small import ( BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST, + BlenderbotSmallForCausalLM, BlenderbotSmallForConditionalGeneration, BlenderbotSmallModel, BlenderbotSmallPreTrainedModel, diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 8d576b8b3a..13ec8464e7 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -15,6 +15,7 @@ """ PyTorch BlenderbotSmall model. """ +import copy import math import random from typing import Optional, Tuple @@ -35,6 +36,7 @@ from ...file_utils import ( from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, ) @@ -805,6 +807,31 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): self.init_weights() + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length + ).to(self.device) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + def forward( self, input_ids=None, @@ -907,19 +934,9 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length - ).to(self.device) - - if attention_mask is not None and combined_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = combined_attention_mask + _expand_mask( - attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: @@ -938,7 +955,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - all_cross_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None next_decoder_cache = () if use_cache else None if head_mask is not None: @@ -974,7 +991,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(decoder_layer), hidden_states, - combined_attention_mask, + attention_mask, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, @@ -985,7 +1002,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): layer_outputs = decoder_layer( hidden_states, - attention_mask=combined_attention_mask, + attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), @@ -1001,7 +1018,9 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): if output_attentions: all_self_attns += (layer_outputs[1],) - all_cross_attentions += (layer_outputs[2],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) # add hidden states from the last decoder layer if output_hidden_states: @@ -1310,3 +1329,210 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel): tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], ) return reordered_past + + +# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->BlenderbotSmall +class BlenderbotSmallDecoderWrapper(BlenderbotSmallPreTrainedModel): + """ + This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is + used in combination with the :class:`~transformers.EncoderDecoderModel` framework. + """ + + def __init__(self, config): + super().__init__(config) + self.decoder = BlenderbotSmallDecoder(config) + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->BlenderbotSmall +class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel): + def __init__(self, config): + super().__init__(config) + config = copy.deepcopy(config) + config.is_decoder = True + config.is_encoder_decoder = False + self.model = BlenderbotSmallDecoderWrapper(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.init_weights() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + head_mask=None, + encoder_head_mask=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` + for details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + if the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used + in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up + decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last ``decoder_input_ids`` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should either be in ``[0, ..., + config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., + config.vocab_size]``. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors + for more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. + + Returns: + + Example:: + + >>> from transformers import BlenderbotSmallTokenizer, BlenderbotSmallForCausalLM + + >>> tokenizer = BlenderbotSmallTokenizer.from_pretrained('facebook/bart-large') + >>> model = BlenderbotSmallForCausalLM.from_pretrained('facebook/bart-large', add_cross_attention=False) + >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_head_mask=encoder_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.lm_head(outputs[0]) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + + if past: + input_ids = input_ids[:, -1:] + # first step, decoder_cached_states are empty + return { + "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed + "attention_mask": attention_mask, + "past_key_values": past, + "use_cache": use_cache, + } + + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past diff --git a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py index 2ef1f9c149..4382708d76 100644 --- a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py @@ -925,7 +925,7 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer): tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1] ) - if inputs["attention_mask"] is not None and input_shape[-1] > 1: + if inputs["attention_mask"] is not None: combined_attention_mask = combined_attention_mask + _expand_mask( inputs["attention_mask"], tgt_len=input_shape[-1] ) diff --git a/src/transformers/models/marian/__init__.py b/src/transformers/models/marian/__init__.py index ab1c5e8130..34a35922c8 100644 --- a/src/transformers/models/marian/__init__.py +++ b/src/transformers/models/marian/__init__.py @@ -39,6 +39,7 @@ if is_torch_available(): "MarianModel", "MarianMTModel", "MarianPreTrainedModel", + "MarianForCausalLM", ] if is_tf_available(): @@ -54,6 +55,7 @@ if TYPE_CHECKING: if is_torch_available(): from .modeling_marian import ( MARIAN_PRETRAINED_MODEL_ARCHIVE_LIST, + MarianForCausalLM, MarianModel, MarianMTModel, MarianPreTrainedModel, diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 39f3c87b00..31711909b0 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -15,6 +15,7 @@ """PyTorch MarianMTModel model, ported from the Marian C++ repo.""" +import copy import math import random from typing import Optional, Tuple @@ -36,6 +37,7 @@ from ...file_utils import ( from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, ) @@ -809,6 +811,31 @@ class MarianDecoder(MarianPreTrainedModel): self.layers = nn.ModuleList([MarianDecoderLayer(config) for _ in range(config.decoder_layers)]) self.init_weights() + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length + ).to(self.device) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + def forward( self, input_ids=None, @@ -911,19 +938,9 @@ class MarianDecoder(MarianPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length - ).to(self.device) - - if attention_mask is not None and combined_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = combined_attention_mask + _expand_mask( - attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: @@ -940,7 +957,7 @@ class MarianDecoder(MarianPreTrainedModel): # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - all_cross_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None next_decoder_cache = () if use_cache else None # check if head_mask has a correct number of layers specified if desired @@ -977,7 +994,7 @@ class MarianDecoder(MarianPreTrainedModel): layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(decoder_layer), hidden_states, - combined_attention_mask, + attention_mask, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, @@ -988,7 +1005,7 @@ class MarianDecoder(MarianPreTrainedModel): layer_outputs = decoder_layer( hidden_states, - attention_mask=combined_attention_mask, + attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), @@ -1004,7 +1021,9 @@ class MarianDecoder(MarianPreTrainedModel): if output_attentions: all_self_attns += (layer_outputs[1],) - all_cross_attentions += (layer_outputs[2],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) # add hidden states from the last decoder layer if output_hidden_states: @@ -1321,3 +1340,210 @@ class MarianMTModel(MarianPreTrainedModel): tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], ) return reordered_past + + +# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Marian +class MarianDecoderWrapper(MarianPreTrainedModel): + """ + This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is + used in combination with the :class:`~transformers.EncoderDecoderModel` framework. + """ + + def __init__(self, config): + super().__init__(config) + self.decoder = MarianDecoder(config) + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Marian +class MarianForCausalLM(MarianPreTrainedModel): + def __init__(self, config): + super().__init__(config) + config = copy.deepcopy(config) + config.is_decoder = True + config.is_encoder_decoder = False + self.model = MarianDecoderWrapper(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.init_weights() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + head_mask=None, + encoder_head_mask=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` + for details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + if the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used + in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up + decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last ``decoder_input_ids`` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should either be in ``[0, ..., + config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., + config.vocab_size]``. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors + for more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. + + Returns: + + Example:: + + >>> from transformers import MarianTokenizer, MarianForCausalLM + + >>> tokenizer = MarianTokenizer.from_pretrained('facebook/bart-large') + >>> model = MarianForCausalLM.from_pretrained('facebook/bart-large', add_cross_attention=False) + >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_head_mask=encoder_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.lm_head(outputs[0]) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + + if past: + input_ids = input_ids[:, -1:] + # first step, decoder_cached_states are empty + return { + "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed + "attention_mask": attention_mask, + "past_key_values": past, + "use_cache": use_cache, + } + + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past diff --git a/src/transformers/models/marian/modeling_tf_marian.py b/src/transformers/models/marian/modeling_tf_marian.py index e9b31d23a4..1944e1dac4 100644 --- a/src/transformers/models/marian/modeling_tf_marian.py +++ b/src/transformers/models/marian/modeling_tf_marian.py @@ -943,7 +943,7 @@ class TFMarianDecoder(tf.keras.layers.Layer): tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1] ) - if inputs["attention_mask"] is not None and input_shape[-1] > 1: + if inputs["attention_mask"] is not None: combined_attention_mask = combined_attention_mask + _expand_mask( inputs["attention_mask"], tgt_len=input_shape[-1] ) diff --git a/src/transformers/models/mbart/__init__.py b/src/transformers/models/mbart/__init__.py index db2aca8494..e6f2d53b80 100644 --- a/src/transformers/models/mbart/__init__.py +++ b/src/transformers/models/mbart/__init__.py @@ -39,6 +39,7 @@ if is_tokenizers_available(): if is_torch_available(): _import_structure["modeling_mbart"] = [ "MBART_PRETRAINED_MODEL_ARCHIVE_LIST", + "MBartForCausalLM", "MBartForConditionalGeneration", "MBartForQuestionAnswering", "MBartForSequenceClassification", @@ -62,6 +63,7 @@ if TYPE_CHECKING: if is_torch_available(): from .modeling_mbart import ( MBART_PRETRAINED_MODEL_ARCHIVE_LIST, + MBartForCausalLM, MBartForConditionalGeneration, MBartForQuestionAnswering, MBartForSequenceClassification, diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 333cd1bccc..11c7df5d7d 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ PyTorch MBART model. """ - - +import copy import math import random from typing import Optional, Tuple @@ -36,6 +35,7 @@ from ...file_utils import ( from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, Seq2SeqQuestionAnsweringModelOutput, @@ -852,6 +852,31 @@ class MBartDecoder(MBartPreTrainedModel): self.init_weights() + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length + ).to(self.device) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + def forward( self, input_ids=None, @@ -954,19 +979,9 @@ class MBartDecoder(MBartPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length - ).to(self.device) - - if attention_mask is not None and combined_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = combined_attention_mask + _expand_mask( - attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: @@ -984,7 +999,7 @@ class MBartDecoder(MBartPreTrainedModel): # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - all_cross_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None next_decoder_cache = () if use_cache else None # check if head_mask has a correct number of layers specified if desired @@ -1021,7 +1036,7 @@ class MBartDecoder(MBartPreTrainedModel): layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(decoder_layer), hidden_states, - combined_attention_mask, + attention_mask, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, @@ -1032,7 +1047,7 @@ class MBartDecoder(MBartPreTrainedModel): layer_outputs = decoder_layer( hidden_states, - attention_mask=combined_attention_mask, + attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), @@ -1048,7 +1063,9 @@ class MBartDecoder(MBartPreTrainedModel): if output_attentions: all_self_attns += (layer_outputs[1],) - all_cross_attentions += (layer_outputs[2],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) hidden_states = self.layer_norm(hidden_states) @@ -1570,3 +1587,210 @@ class MBartForQuestionAnswering(MBartPreTrainedModel): encoder_hidden_states=outputs.encoder_hidden_states, encoder_attentions=outputs.encoder_attentions, ) + + +# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->MBart +class MBartDecoderWrapper(MBartPreTrainedModel): + """ + This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is + used in combination with the :class:`~transformers.EncoderDecoderModel` framework. + """ + + def __init__(self, config): + super().__init__(config) + self.decoder = MBartDecoder(config) + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->MBart +class MBartForCausalLM(MBartPreTrainedModel): + def __init__(self, config): + super().__init__(config) + config = copy.deepcopy(config) + config.is_decoder = True + config.is_encoder_decoder = False + self.model = MBartDecoderWrapper(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.init_weights() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + head_mask=None, + encoder_head_mask=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` + for details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + if the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used + in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up + decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last ``decoder_input_ids`` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should either be in ``[0, ..., + config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., + config.vocab_size]``. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors + for more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. + + Returns: + + Example:: + + >>> from transformers import MBartTokenizer, MBartForCausalLM + + >>> tokenizer = MBartTokenizer.from_pretrained('facebook/bart-large') + >>> model = MBartForCausalLM.from_pretrained('facebook/bart-large', add_cross_attention=False) + >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_head_mask=encoder_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.lm_head(outputs[0]) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + + if past: + input_ids = input_ids[:, -1:] + # first step, decoder_cached_states are empty + return { + "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed + "attention_mask": attention_mask, + "past_key_values": past, + "use_cache": use_cache, + } + + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past diff --git a/src/transformers/models/mbart/modeling_tf_mbart.py b/src/transformers/models/mbart/modeling_tf_mbart.py index cddd6917f5..6518b98836 100644 --- a/src/transformers/models/mbart/modeling_tf_mbart.py +++ b/src/transformers/models/mbart/modeling_tf_mbart.py @@ -938,7 +938,7 @@ class TFMBartDecoder(tf.keras.layers.Layer): tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1] ) - if inputs["attention_mask"] is not None and input_shape[-1] > 1: + if inputs["attention_mask"] is not None: combined_attention_mask = combined_attention_mask + _expand_mask( inputs["attention_mask"], tgt_len=input_shape[-1] ) diff --git a/src/transformers/models/pegasus/__init__.py b/src/transformers/models/pegasus/__init__.py index e263c1b86c..50e6284be8 100644 --- a/src/transformers/models/pegasus/__init__.py +++ b/src/transformers/models/pegasus/__init__.py @@ -42,6 +42,7 @@ if is_torch_available(): "PegasusForConditionalGeneration", "PegasusModel", "PegasusPreTrainedModel", + "PegasusForCausalLM", ] if is_tf_available(): @@ -60,6 +61,7 @@ if TYPE_CHECKING: if is_torch_available(): from .modeling_pegasus import ( PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST, + PegasusForCausalLM, PegasusForConditionalGeneration, PegasusModel, PegasusPreTrainedModel, diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index f1a41d51d5..573f21f2d7 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -14,7 +14,7 @@ # limitations under the License. """ PyTorch PEGASUS model. """ - +import copy import math import random from typing import Optional, Tuple @@ -36,6 +36,7 @@ from ...file_utils import ( from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, ) @@ -817,6 +818,31 @@ class PegasusDecoder(PegasusPreTrainedModel): self.init_weights() + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length + ).to(self.device) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + def forward( self, input_ids=None, @@ -919,19 +945,9 @@ class PegasusDecoder(PegasusPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length - ).to(self.device) - - if attention_mask is not None and combined_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = combined_attention_mask + _expand_mask( - attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: @@ -948,7 +964,7 @@ class PegasusDecoder(PegasusPreTrainedModel): # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - all_cross_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None next_decoder_cache = () if use_cache else None # check if head_mask has a correct number of layers specified if desired @@ -985,7 +1001,7 @@ class PegasusDecoder(PegasusPreTrainedModel): layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(decoder_layer), hidden_states, - combined_attention_mask, + attention_mask, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, @@ -996,7 +1012,7 @@ class PegasusDecoder(PegasusPreTrainedModel): layer_outputs = decoder_layer( hidden_states, - attention_mask=combined_attention_mask, + attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), @@ -1012,7 +1028,9 @@ class PegasusDecoder(PegasusPreTrainedModel): if output_attentions: all_self_attns += (layer_outputs[1],) - all_cross_attentions += (layer_outputs[2],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) hidden_states = self.layer_norm(hidden_states) @@ -1325,3 +1343,210 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel): tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], ) return reordered_past + + +# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Pegasus +class PegasusDecoderWrapper(PegasusPreTrainedModel): + """ + This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is + used in combination with the :class:`~transformers.EncoderDecoderModel` framework. + """ + + def __init__(self, config): + super().__init__(config) + self.decoder = PegasusDecoder(config) + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Pegasus +class PegasusForCausalLM(PegasusPreTrainedModel): + def __init__(self, config): + super().__init__(config) + config = copy.deepcopy(config) + config.is_decoder = True + config.is_encoder_decoder = False + self.model = PegasusDecoderWrapper(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.init_weights() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + head_mask=None, + encoder_head_mask=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` + for details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + if the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used + in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up + decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last ``decoder_input_ids`` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should either be in ``[0, ..., + config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., + config.vocab_size]``. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors + for more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. + + Returns: + + Example:: + + >>> from transformers import PegasusTokenizer, PegasusForCausalLM + + >>> tokenizer = PegasusTokenizer.from_pretrained('facebook/bart-large') + >>> model = PegasusForCausalLM.from_pretrained('facebook/bart-large', add_cross_attention=False) + >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_head_mask=encoder_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.lm_head(outputs[0]) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + + if past: + input_ids = input_ids[:, -1:] + # first step, decoder_cached_states are empty + return { + "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed + "attention_mask": attention_mask, + "past_key_values": past, + "use_cache": use_cache, + } + + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past diff --git a/src/transformers/models/pegasus/modeling_tf_pegasus.py b/src/transformers/models/pegasus/modeling_tf_pegasus.py index 3eb7b3f5a2..8eacf89613 100644 --- a/src/transformers/models/pegasus/modeling_tf_pegasus.py +++ b/src/transformers/models/pegasus/modeling_tf_pegasus.py @@ -955,7 +955,7 @@ class TFPegasusDecoder(tf.keras.layers.Layer): tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1] ) - if inputs["attention_mask"] is not None and input_shape[-1] > 1: + if inputs["attention_mask"] is not None: combined_attention_mask = combined_attention_mask + _expand_mask( inputs["attention_mask"], tgt_len=input_shape[-1] ) diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index d473e8758a..f8cbb11f6f 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -1330,7 +1330,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel): >>> import torch >>> tokenizer = ProphetNetTokenizer.from_pretrained('microsoft/prophetnet-large-uncased') - >>> model = ProphetNetDecoder.from_pretrained('patrickvonplaten/prophetnet-large-uncased-standalone', add_cross_attention=False) + >>> model = ProphetNetDecoder.from_pretrained('microsoft/prophetnet-large-uncased', add_cross_attention=False) >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") >>> outputs = model(**inputs) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 995a87d5b9..b246074fc5 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -431,6 +431,11 @@ class AutoModelWithLMHead: BART_PRETRAINED_MODEL_ARCHIVE_LIST = None +class BartForCausalLM: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + class BartForConditionalGeneration: def __init__(self, *args, **kwargs): requires_pytorch(self) @@ -596,6 +601,11 @@ def load_tf_weights_in_bert_generation(*args, **kwargs): BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST = None +class BlenderbotForCausalLM: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + class BlenderbotForConditionalGeneration: def __init__(self, *args, **kwargs): requires_pytorch(self) @@ -617,6 +627,11 @@ class BlenderbotModel: BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST = None +class BlenderbotSmallForCausalLM: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + class BlenderbotSmallForConditionalGeneration: def __init__(self, *args, **kwargs): requires_pytorch(self) @@ -1464,6 +1479,11 @@ class LxmertXLayer: requires_pytorch(self) +class MarianForCausalLM: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + class MarianModel: def __init__(self, *args, **kwargs): requires_pytorch(self) @@ -1482,6 +1502,11 @@ class MarianMTModel: requires_pytorch(self) +class MBartForCausalLM: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + class MBartForConditionalGeneration: def __init__(self, *args, **kwargs): requires_pytorch(self) @@ -1772,6 +1797,11 @@ def load_tf_weights_in_openai_gpt(*args, **kwargs): requires_pytorch(load_tf_weights_in_openai_gpt) +class PegasusForCausalLM: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + class PegasusForConditionalGeneration: def __init__(self, *args, **kwargs): requires_pytorch(self) diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py index 0173e8e5bc..b025a2eac4 100755 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py @@ -1522,6 +1522,7 @@ class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.ca ) {% else %} import math +import copy import random from typing import Optional, Tuple @@ -1663,6 +1664,7 @@ class {{cookiecutter.camelcase_modelname}}Attention(nn.Module): key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -1730,6 +1732,13 @@ class {{cookiecutter.camelcase_modelname}}Attention(nn.Module): attn_weights = F.softmax(attn_weights, dim=-1) + if layer_head_mask is not None: + assert layer_head_mask.size() == ( + self.num_heads, + ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + if output_attentions: # this operation is a bit akward, but it's required to # make sure that attn_weights keeps its gradient. @@ -1778,19 +1787,30 @@ class {{cookiecutter.camelcase_modelname}}EncoderLayer(nn.Module): self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) - def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: bool = False): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + layer_head_mask: torch.Tensor, + output_attentions: bool = False, + ): """ Args: hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` attention_mask (:obj:`torch.FloatTensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size + `(config.encoder_attention_heads,)`. output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned tensors for more detail. """ residual = hidden_states hidden_states, attn_weights, _ = self.self_attn( - hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, ) hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -1849,6 +1869,8 @@ class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module): attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + encoder_layer_head_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, @@ -1861,6 +1883,10 @@ class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module): encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size + `(config.encoder_attention_heads,)`. + encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of + size `(config.encoder_attention_heads,)`. past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under @@ -1876,6 +1902,7 @@ class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module): hidden_states=hidden_states, past_key_value=self_attn_past_key_value, attention_mask=attention_mask, + layer_head_mask=layer_head_mask, output_attentions=output_attentions, ) hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) @@ -1894,6 +1921,7 @@ class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module): hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, + layer_head_mask=encoder_layer_head_mask, past_key_value=cross_attn_past_key_value, output_attentions=output_attentions, ) @@ -1924,7 +1952,7 @@ class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module): return outputs -# Copied from transformers.models.bart.modeling_bart.BartClassificationHead with Bart->{{cookiecutter.camelcase_modelname}} +# Copied from transformers.models.bart.modeling_bart.{{cookiecutter.camelcase_modelname}}ClassificationHead with {{cookiecutter.camelcase_modelname}}->{{cookiecutter.camelcase_modelname}} class {{cookiecutter.camelcase_modelname}}ClassificationHead(nn.Module): """Head for sentence-level classification tasks.""" @@ -2036,6 +2064,18 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel): If you want to change padding behavior, you should read :func:`modeling_{{cookiecutter.lowercase_modelname}}._prepare_decoder_inputs` and modify to your needs. See diagram 1 in `the paper `__ for more information on the default strategy. + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, @@ -2073,6 +2113,35 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel): """ +{{cookiecutter.uppercase_modelname}}_STANDALONE_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + + class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_modelname}}PreTrainedModel): """ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a @@ -2113,6 +2182,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model self, input_ids=None, attention_mask=None, + head_mask=None, inputs_embeds=None, output_attentions=None, output_hidden_states=None, @@ -2136,6 +2206,11 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert :obj:`input_ids` indices @@ -2182,7 +2257,14 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None - for encoder_layer in self.layers: + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + + for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) @@ -2202,9 +2284,15 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model create_custom_forward(encoder_layer), hidden_states, attention_mask, + (head_mask[idx] if head_mask is not None else None), ) else: - layer_outputs = encoder_layer(hidden_states, attention_mask, output_attentions=output_attentions) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -2253,12 +2341,39 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model self.init_weights() + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length + ).to(self.device) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + def forward( self, input_ids=None, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + head_mask=None, + encoder_head_mask=None, past_key_values=None, inputs_embeds=None, use_cache=None, @@ -2295,6 +2410,19 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. @@ -2340,19 +2468,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length - ).to(self.device) - - if attention_mask is not None and combined_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = combined_attention_mask + _expand_mask( - attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) + attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: @@ -2370,8 +2486,15 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - all_cross_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None next_decoder_cache = () if use_cache else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if output_hidden_states: @@ -2398,18 +2521,22 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(decoder_layer), hidden_states, - combined_attention_mask, + attention_mask, encoder_hidden_states, encoder_attention_mask, + head_mask[idx] if head_mask is not None else None, + encoder_head_mask[idx] if encoder_head_mask is not None else None, None, ) else: layer_outputs = decoder_layer( hidden_states, - attention_mask=combined_attention_mask, + attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, @@ -2421,7 +2548,9 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model if output_attentions: all_self_attns += (layer_outputs[1],) - all_cross_attentions += (layer_outputs[2],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) # add hidden states from the last decoder layer if output_hidden_states: @@ -2486,6 +2615,8 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, @@ -2506,6 +2637,7 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna encoder_outputs = self.encoder( input_ids=input_ids, attention_mask=attention_mask, + head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -2525,6 +2657,8 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna attention_mask=decoder_attention_mask, encoder_hidden_states=encoder_outputs[0], encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + encoder_head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=decoder_inputs_embeds, use_cache=use_cache, @@ -2603,6 +2737,8 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, @@ -2649,6 +2785,8 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte decoder_input_ids=decoder_input_ids, encoder_outputs=encoder_outputs, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, @@ -2681,7 +2819,14 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte ) def prepare_inputs_for_generation( - self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs + self, + decoder_input_ids, + past=None, + attention_mask=None, + head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs ): # cut decoder_input_ids if past is used if past is not None: @@ -2693,6 +2838,7 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte "past_key_values": past, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, + "head_mask": head_mask, "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } @@ -2919,4 +3065,210 @@ class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.ca encoder_hidden_states=outputs.encoder_hidden_states, encoder_attentions=outputs.encoder_attentions, ) + +# Copied from transformers.models.bart.modeling_bart.{{cookiecutter.camelcase_modelname}}DecoderWrapper with {{cookiecutter.camelcase_modelname}}->{{cookiecutter.camelcase_modelname}} +class {{cookiecutter.camelcase_modelname}}DecoderWrapper({{cookiecutter.camelcase_modelname}}PretrainedModel): + """ + This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is + used in combination with the :class:`~transformers.EncoderDecoderModel` framework. + """ + + def __init__(self, config): + super().__init__(config) + self.decoder = {{cookiecutter.camelcase_modelname}}Decoder(config) + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +# Copied from transformers.models.bart.modeling_bart.{{cookiecutter.camelcase_modelname}}ForCausalLM with {{cookiecutter.camelcase_modelname}}->{{cookiecutter.camelcase_modelname}} +class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_modelname}}PretrainedModel): + def __init__(self, config): + super().__init__(config) + config = copy.deepcopy(config) + config.is_decoder = True + config.is_encoder_decoder = False + self.model = {{cookiecutter.camelcase_modelname}}DecoderWrapper(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.init_weights() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + head_mask=None, + encoder_head_mask=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using :class:`~transformers.{{cookiecutter.camelcase_modelname}}Tokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` + for details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + if the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used + in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up + decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last ``decoder_input_ids`` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should either be in ``[0, ..., + config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., + config.vocab_size]``. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors + for more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. + + Returns: + + Example:: + + >>> from transformers import {{cookiecutter.camelcase_modelname}}Tokenizer, {{cookiecutter.camelcase_modelname}}ForCausalLM + + >>> tokenizer = {{cookiecutter.camelcase_modelname}}Tokenizer.from_pretrained('{{cookiecutter.checkpoint_identifier}}') + >>> model = {{cookiecutter.camelcase_modelname}}ForCausalLM.from_pretrained('{{cookiecutter.checkpoint_identifier}}', add_cross_attention=False) + >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_head_mask=encoder_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.lm_head(outputs[0]) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + + if past: + input_ids = input_ids[:, -1:] + # first step, decoder_cached_states are empty + return { + "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed + "attention_mask": attention_mask, + "past_key_values": past, + "use_cache": use_cache, + } + + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past {% endif -%} diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_{{cookiecutter.lowercase_modelname}}.py index 9d7d303df7..ad1a9f94b3 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_{{cookiecutter.lowercase_modelname}}.py @@ -488,7 +488,7 @@ from transformers.testing_utils import require_sentencepiece, require_tokenizers from .test_configuration_common import ConfigTester from .test_generation_utils import GenerationTesterMixin -from .test_modeling_common import ModelTesterMixin, ids_tensor +from .test_modeling_common import ModelTesterMixin, ids_tensor, floats_tensor if is_torch_available(): @@ -847,4 +847,220 @@ class {{cookiecutter.camelcase_modelname}}ModelIntegrationTests(unittest.TestCas hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True ) assert generated == EXPECTED + + +class {{cookiecutter.camelcase_modelname}}StandaloneDecoderModelTester: + def __init__( + self, + parent, + vocab_size=99, + batch_size=13, + d_model=16, + decoder_seq_length=7, + is_training=True, + is_decoder=True, + use_attention_mask=True, + use_cache=False, + use_labels=True, + decoder_start_token_id=2, + decoder_ffn_dim=32, + decoder_layers=4, + encoder_attention_heads=4, + decoder_attention_heads=4, + max_position_embeddings=30, + is_encoder_decoder=False, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.decoder_seq_length = decoder_seq_length + # For common tests + self.seq_length = self.decoder_seq_length + self.is_training = is_training + self.use_attention_mask = use_attention_mask + self.use_labels = use_labels + + self.vocab_size = vocab_size + self.d_model = d_model + self.hidden_size = d_model + self.num_hidden_layers = decoder_layers + self.decoder_layers = decoder_layers + self.decoder_ffn_dim = decoder_ffn_dim + self.encoder_attention_heads = encoder_attention_heads + self.decoder_attention_heads = decoder_attention_heads + self.num_attention_heads = decoder_attention_heads + self.eos_token_id = eos_token_id + self.bos_token_id = bos_token_id + self.pad_token_id = pad_token_id + self.decoder_start_token_id = decoder_start_token_id + self.use_cache = use_cache + self.max_position_embeddings = max_position_embeddings + self.is_encoder_decoder = is_encoder_decoder + + self.scope = None + self.decoder_key_length = decoder_seq_length + self.base_model_out_len = 2 + self.decoder_attention_idx = 1 + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + attention_mask = None + if self.use_attention_mask: + attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2) + + lm_labels = None + if self.use_labels: + lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + config = {{cookiecutter.camelcase_modelname}}Config( + vocab_size=self.vocab_size, + d_model=self.d_model, + decoder_layers=self.decoder_layers, + decoder_ffn_dim=self.decoder_ffn_dim, + encoder_attention_heads=self.encoder_attention_heads, + decoder_attention_heads=self.decoder_attention_heads, + eos_token_id=self.eos_token_id, + bos_token_id=self.bos_token_id, + use_cache=self.use_cache, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.decoder_start_token_id, + max_position_embeddings=self.max_position_embeddings, + is_encoder_decoder=self.is_encoder_decoder, + ) + + return ( + config, + input_ids, + attention_mask, + lm_labels, + ) + + def create_and_check_decoder_model_past( + self, + config, + input_ids, + attention_mask, + lm_labels, + ): + config.use_cache = True + model = {{cookiecutter.camelcase_modelname}}Decoder(config=config).to(torch_device).eval() + # first forward pass + outputs = model(input_ids, use_cache=True) + outputs_use_cache_conf = model(input_ids) + outputs_no_past = model(input_ids, use_cache=False) + + self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) + self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) + + past_key_values = outputs["past_key_values"] + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + + output_from_no_past = model(next_input_ids)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3) + + def create_and_check_decoder_model_attention_mask_past( + self, + config, + input_ids, + attention_mask, + lm_labels, + ): + model = {{cookiecutter.camelcase_modelname}}Decoder(config=config).to(torch_device).eval() + + # create attention mask + attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + + half_seq_length = input_ids.shape[-1] // 2 + attn_mask[:, half_seq_length:] = 0 + + # first forward pass + past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"] + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # change a random masked slice from input_ids + random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1 + random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1) + input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens + + # append to next input_ids and attn_mask + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + attn_mask = torch.cat( + [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], + dim=1, + ) + + # get two different outputs + output_from_no_past = model(next_input_ids)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + attention_mask, + lm_labels, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +@require_torch +class {{cookiecutter.camelcase_modelname}}StandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = ({{cookiecutter.camelcase_modelname}}Decoder, {{cookiecutter.camelcase_modelname}}ForCausalLM) if is_torch_available() else () + all_generative_model_classes = ({{cookiecutter.camelcase_modelname}}ForCausalLM,) if is_torch_available() else () + test_pruning = False + is_encoder_decoder = False + + def setUp( + self, + ): + self.model_tester = {{cookiecutter.camelcase_modelname}}StandaloneDecoderModelTester(self, is_training=False) + self.config_tester = ConfigTester(self, config_class={{cookiecutter.camelcase_modelname}}Config) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_decoder_model_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past(*config_and_inputs) + + def test_decoder_model_attn_mask_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs) + + def test_retain_grad_hidden_states_attentions(self): + # decoder cannot keep gradients + return {% endif -%} diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index 9e5cb08430..2e15fa77d8 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -27,7 +27,7 @@ from transformers.testing_utils import require_sentencepiece, require_tokenizers from .test_configuration_common import ConfigTester from .test_generation_utils import GenerationTesterMixin -from .test_modeling_common import ModelTesterMixin, ids_tensor +from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor if is_torch_available(): @@ -36,6 +36,7 @@ if is_torch_available(): from transformers import ( AutoModelForSequenceClassification, BartConfig, + BartForCausalLM, BartForConditionalGeneration, BartForQuestionAnswering, BartForSequenceClassification, @@ -178,7 +179,7 @@ class BartModelTester: self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) # test that outputs are equal for slice - self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2)) + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) def check_encoder_decoder_model_standalone(self, config, inputs_dict): model = BartModel(config=config).to(torch_device).eval() @@ -719,3 +720,242 @@ class BartModelIntegrationTests(unittest.TestCase): hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True ) assert generated_summaries == EXPECTED + + +class BartStandaloneDecoderModelTester: + def __init__( + self, + parent, + vocab_size=99, + batch_size=13, + d_model=16, + decoder_seq_length=7, + is_training=True, + is_decoder=True, + use_attention_mask=True, + use_cache=False, + use_labels=True, + decoder_start_token_id=2, + decoder_ffn_dim=32, + decoder_layers=4, + encoder_attention_heads=4, + decoder_attention_heads=4, + max_position_embeddings=30, + is_encoder_decoder=False, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.decoder_seq_length = decoder_seq_length + # For common tests + self.seq_length = self.decoder_seq_length + self.is_training = is_training + self.use_attention_mask = use_attention_mask + self.use_labels = use_labels + + self.vocab_size = vocab_size + self.d_model = d_model + self.hidden_size = d_model + self.num_hidden_layers = decoder_layers + self.decoder_layers = decoder_layers + self.decoder_ffn_dim = decoder_ffn_dim + self.encoder_attention_heads = encoder_attention_heads + self.decoder_attention_heads = decoder_attention_heads + self.num_attention_heads = decoder_attention_heads + self.eos_token_id = eos_token_id + self.bos_token_id = bos_token_id + self.pad_token_id = pad_token_id + self.decoder_start_token_id = decoder_start_token_id + self.use_cache = use_cache + self.max_position_embeddings = max_position_embeddings + self.is_encoder_decoder = is_encoder_decoder + + self.scope = None + self.decoder_key_length = decoder_seq_length + self.base_model_out_len = 2 + self.decoder_attention_idx = 1 + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + attention_mask = None + if self.use_attention_mask: + attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2) + + lm_labels = None + if self.use_labels: + lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + config = BartConfig( + vocab_size=self.vocab_size, + d_model=self.d_model, + encoder_layers=self.decoder_layers, + decoder_layers=self.decoder_layers, + decoder_ffn_dim=self.decoder_ffn_dim, + encoder_attention_heads=self.encoder_attention_heads, + decoder_attention_heads=self.decoder_attention_heads, + eos_token_id=self.eos_token_id, + bos_token_id=self.bos_token_id, + use_cache=self.use_cache, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.decoder_start_token_id, + max_position_embeddings=self.max_position_embeddings, + is_encoder_decoder=self.is_encoder_decoder, + ) + + return ( + config, + input_ids, + attention_mask, + lm_labels, + ) + + def prepare_config_and_inputs_for_decoder(self): + ( + config, + input_ids, + attention_mask, + lm_labels, + ) = self.prepare_config_and_inputs() + + encoder_hidden_states = floats_tensor([self.batch_size, self.decoder_seq_length, self.hidden_size]) + encoder_attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2) + + return ( + config, + input_ids, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + lm_labels, + ) + + def create_and_check_decoder_model_past( + self, + config, + input_ids, + attention_mask, + lm_labels, + ): + config.use_cache = True + model = BartDecoder(config=config).to(torch_device).eval() + # first forward pass + outputs = model(input_ids, use_cache=True) + outputs_use_cache_conf = model(input_ids) + outputs_no_past = model(input_ids, use_cache=False) + + self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) + self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) + + past_key_values = outputs["past_key_values"] + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + + output_from_no_past = model(next_input_ids)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3) + + def create_and_check_decoder_model_attention_mask_past( + self, + config, + input_ids, + attention_mask, + lm_labels, + ): + model = BartDecoder(config=config).to(torch_device).eval() + + # create attention mask + attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + + half_seq_length = input_ids.shape[-1] // 2 + attn_mask[:, half_seq_length:] = 0 + + # first forward pass + past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"] + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # change a random masked slice from input_ids + random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1 + random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1) + input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens + + # append to next input_ids and attn_mask + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + attn_mask = torch.cat( + [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], + dim=1, + ) + + # get two different outputs + output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] + output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[ + "last_hidden_state" + ] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + attention_mask, + lm_labels, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +@require_torch +class BartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = (BartDecoder, BartForCausalLM) if is_torch_available() else () + all_generative_model_classes = (BartForCausalLM,) if is_torch_available() else () + test_pruning = False + is_encoder_decoder = False + + def setUp( + self, + ): + self.model_tester = BartStandaloneDecoderModelTester(self, is_training=False) + self.config_tester = ConfigTester(self, config_class=BartConfig) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_decoder_model_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past(*config_and_inputs) + + def test_decoder_model_attn_mask_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs) + + def test_retain_grad_hidden_states_attentions(self): + # decoder cannot keep gradients + return diff --git a/tests/test_modeling_blenderbot.py b/tests/test_modeling_blenderbot.py index 974b99f84a..1e5337d65d 100644 --- a/tests/test_modeling_blenderbot.py +++ b/tests/test_modeling_blenderbot.py @@ -14,7 +14,6 @@ # limitations under the License. """ Testing suite for the PyTorch Blenderbot model. """ - import tempfile import unittest @@ -31,7 +30,11 @@ if is_torch_available(): import torch from transformers import BlenderbotConfig, BlenderbotForConditionalGeneration, BlenderbotModel, BlenderbotTokenizer - from transformers.models.blenderbot.modeling_blenderbot import BlenderbotDecoder, BlenderbotEncoder + from transformers.models.blenderbot.modeling_blenderbot import ( + BlenderbotDecoder, + BlenderbotEncoder, + BlenderbotForCausalLM, + ) def prepare_blenderbot_inputs_dict( @@ -165,7 +168,7 @@ class BlenderbotModelTester: self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) # test that outputs are equal for slice - self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2)) + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) def check_encoder_decoder_model_standalone(self, config, inputs_dict): model = BlenderbotModel(config=config).to(torch_device).eval() @@ -300,3 +303,222 @@ class Blenderbot3BIntegrationTests(unittest.TestCase): assert "I think it's because we are so worried about what people think of us." == reply.strip() del model + + +class BlenderbotStandaloneDecoderModelTester: + def __init__( + self, + parent, + vocab_size=99, + batch_size=13, + d_model=16, + decoder_seq_length=7, + is_training=True, + is_decoder=True, + use_attention_mask=True, + use_cache=False, + use_labels=True, + decoder_start_token_id=2, + decoder_ffn_dim=32, + decoder_layers=4, + encoder_attention_heads=4, + decoder_attention_heads=4, + max_position_embeddings=30, + is_encoder_decoder=False, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.decoder_seq_length = decoder_seq_length + # For common tests + self.seq_length = self.decoder_seq_length + self.is_training = is_training + self.use_attention_mask = use_attention_mask + self.use_labels = use_labels + + self.vocab_size = vocab_size + self.d_model = d_model + self.hidden_size = d_model + self.num_hidden_layers = decoder_layers + self.decoder_layers = decoder_layers + self.decoder_ffn_dim = decoder_ffn_dim + self.encoder_attention_heads = encoder_attention_heads + self.decoder_attention_heads = decoder_attention_heads + self.num_attention_heads = decoder_attention_heads + self.eos_token_id = eos_token_id + self.bos_token_id = bos_token_id + self.pad_token_id = pad_token_id + self.decoder_start_token_id = decoder_start_token_id + self.use_cache = use_cache + self.max_position_embeddings = max_position_embeddings + self.is_encoder_decoder = is_encoder_decoder + + self.scope = None + self.decoder_key_length = decoder_seq_length + self.base_model_out_len = 2 + self.decoder_attention_idx = 1 + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + attention_mask = None + if self.use_attention_mask: + attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2) + + lm_labels = None + if self.use_labels: + lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + config = BlenderbotConfig( + vocab_size=self.vocab_size, + d_model=self.d_model, + decoder_layers=self.decoder_layers, + decoder_ffn_dim=self.decoder_ffn_dim, + encoder_attention_heads=self.encoder_attention_heads, + decoder_attention_heads=self.decoder_attention_heads, + eos_token_id=self.eos_token_id, + bos_token_id=self.bos_token_id, + use_cache=self.use_cache, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.decoder_start_token_id, + max_position_embeddings=self.max_position_embeddings, + is_encoder_decoder=self.is_encoder_decoder, + ) + + return ( + config, + input_ids, + attention_mask, + lm_labels, + ) + + def create_and_check_decoder_model_past( + self, + config, + input_ids, + attention_mask, + lm_labels, + ): + config.use_cache = True + model = BlenderbotDecoder(config=config).to(torch_device).eval() + # first forward pass + outputs = model(input_ids, use_cache=True) + outputs_use_cache_conf = model(input_ids) + outputs_no_past = model(input_ids, use_cache=False) + + self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) + self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) + + past_key_values = outputs["past_key_values"] + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + + output_from_no_past = model(next_input_ids)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3) + + def create_and_check_decoder_model_attention_mask_past( + self, + config, + input_ids, + attention_mask, + lm_labels, + ): + model = BlenderbotDecoder(config=config).to(torch_device).eval() + + # create attention mask + attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + + half_seq_length = input_ids.shape[-1] // 2 + attn_mask[:, half_seq_length:] = 0 + + # first forward pass + past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"] + # past_key_values = model(input_ids, use_cache=True)["past_key_values"] + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # change a random masked slice from input_ids + random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1 + random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1) + input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens + + # append to next input_ids and attn_mask + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + attn_mask = torch.cat( + [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], + dim=1, + ) + + # get two different outputs + output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[ + "last_hidden_state" + ] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + attention_mask, + lm_labels, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +@require_torch +class BlenderbotStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = (BlenderbotDecoder, BlenderbotForCausalLM) if is_torch_available() else () + all_generative_model_classes = (BlenderbotForCausalLM,) if is_torch_available() else () + test_pruning = False + is_encoder_decoder = False + + def setUp( + self, + ): + self.model_tester = BlenderbotStandaloneDecoderModelTester(self, is_training=False) + self.config_tester = ConfigTester(self, config_class=BlenderbotConfig) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_decoder_model_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past(*config_and_inputs) + + def test_decoder_model_attn_mask_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs) + + def test_retain_grad_hidden_states_attentions(self): + # decoder cannot keep gradients + return diff --git a/tests/test_modeling_blenderbot_small.py b/tests/test_modeling_blenderbot_small.py index eff4d5b545..a4c68085f6 100644 --- a/tests/test_modeling_blenderbot_small.py +++ b/tests/test_modeling_blenderbot_small.py @@ -14,7 +14,6 @@ # limitations under the License. """ Testing suite for the PyTorch BlenderbotSmall model. """ - import tempfile import unittest @@ -39,6 +38,7 @@ if is_torch_available(): from transformers.models.blenderbot_small.modeling_blenderbot_small import ( BlenderbotSmallDecoder, BlenderbotSmallEncoder, + BlenderbotSmallForCausalLM, ) @@ -173,7 +173,7 @@ class BlenderbotSmallModelTester: self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) # test that outputs are equal for slice - self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2)) + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) def check_encoder_decoder_model_standalone(self, config, inputs_dict): model = BlenderbotSmallModel(config=config).to(torch_device).eval() @@ -317,3 +317,221 @@ class Blenderbot90MIntegrationTests(unittest.TestCase): "have you ever been to a sam club? it's a great club in the south.", "have you ever heard of sam harris? he's an american singer, songwriter, and actor.", ) + + +class BlenderbotSmallStandaloneDecoderModelTester: + def __init__( + self, + parent, + vocab_size=99, + batch_size=13, + d_model=16, + decoder_seq_length=7, + is_training=True, + is_decoder=True, + use_attention_mask=True, + use_cache=False, + use_labels=True, + decoder_start_token_id=2, + decoder_ffn_dim=32, + decoder_layers=4, + encoder_attention_heads=4, + decoder_attention_heads=4, + max_position_embeddings=30, + is_encoder_decoder=False, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.decoder_seq_length = decoder_seq_length + # For common tests + self.seq_length = self.decoder_seq_length + self.is_training = is_training + self.use_attention_mask = use_attention_mask + self.use_labels = use_labels + + self.vocab_size = vocab_size + self.d_model = d_model + self.hidden_size = d_model + self.num_hidden_layers = decoder_layers + self.decoder_layers = decoder_layers + self.decoder_ffn_dim = decoder_ffn_dim + self.encoder_attention_heads = encoder_attention_heads + self.decoder_attention_heads = decoder_attention_heads + self.num_attention_heads = decoder_attention_heads + self.eos_token_id = eos_token_id + self.bos_token_id = bos_token_id + self.pad_token_id = pad_token_id + self.decoder_start_token_id = decoder_start_token_id + self.use_cache = use_cache + self.max_position_embeddings = max_position_embeddings + self.is_encoder_decoder = is_encoder_decoder + + self.scope = None + self.decoder_key_length = decoder_seq_length + self.base_model_out_len = 2 + self.decoder_attention_idx = 1 + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + attention_mask = None + if self.use_attention_mask: + attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2) + + lm_labels = None + if self.use_labels: + lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + config = BlenderbotSmallConfig( + vocab_size=self.vocab_size, + d_model=self.d_model, + decoder_layers=self.decoder_layers, + decoder_ffn_dim=self.decoder_ffn_dim, + encoder_attention_heads=self.encoder_attention_heads, + decoder_attention_heads=self.decoder_attention_heads, + eos_token_id=self.eos_token_id, + bos_token_id=self.bos_token_id, + use_cache=self.use_cache, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.decoder_start_token_id, + max_position_embeddings=self.max_position_embeddings, + is_encoder_decoder=self.is_encoder_decoder, + ) + + return ( + config, + input_ids, + attention_mask, + lm_labels, + ) + + def create_and_check_decoder_model_past( + self, + config, + input_ids, + attention_mask, + lm_labels, + ): + config.use_cache = True + model = BlenderbotSmallDecoder(config=config).to(torch_device).eval() + # first forward pass + outputs = model(input_ids, use_cache=True) + outputs_use_cache_conf = model(input_ids) + outputs_no_past = model(input_ids, use_cache=False) + + self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) + self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) + + past_key_values = outputs["past_key_values"] + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + + output_from_no_past = model(next_input_ids)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3) + + def create_and_check_decoder_model_attention_mask_past( + self, + config, + input_ids, + attention_mask, + lm_labels, + ): + model = BlenderbotSmallDecoder(config=config).to(torch_device).eval() + + # create attention mask + attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + + half_seq_length = input_ids.shape[-1] // 2 + attn_mask[:, half_seq_length:] = 0 + + # first forward pass + past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"] + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # change a random masked slice from input_ids + random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1 + random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1) + input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens + + # append to next input_ids and attn_mask + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + attn_mask = torch.cat( + [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], + dim=1, + ) + + # get two different outputs + output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[ + "last_hidden_state" + ] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + attention_mask, + lm_labels, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +@require_torch +class BlenderbotSmallStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = (BlenderbotSmallDecoder, BlenderbotSmallForCausalLM) if is_torch_available() else () + all_generative_model_classes = (BlenderbotSmallForCausalLM,) if is_torch_available() else () + test_pruning = False + is_encoder_decoder = False + + def setUp( + self, + ): + self.model_tester = BlenderbotSmallStandaloneDecoderModelTester(self, is_training=False) + self.config_tester = ConfigTester(self, config_class=BlenderbotSmallConfig) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_decoder_model_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past(*config_and_inputs) + + def test_decoder_model_attn_mask_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs) + + def test_retain_grad_hidden_states_attentions(self): + # decoder cannot keep gradients + return diff --git a/tests/test_modeling_encoder_decoder.py b/tests/test_modeling_encoder_decoder.py index 38eca3e809..338de796b9 100644 --- a/tests/test_modeling_encoder_decoder.py +++ b/tests/test_modeling_encoder_decoder.py @@ -20,6 +20,7 @@ import unittest from transformers import is_torch_available from transformers.testing_utils import require_torch, slow, torch_device +from .test_modeling_bart import BartStandaloneDecoderModelTester from .test_modeling_bert import BertModelTester from .test_modeling_bert_generation import BertGenerationEncoderTester from .test_modeling_common import ids_tensor @@ -34,6 +35,7 @@ if is_torch_available(): from transformers import ( AutoTokenizer, + BartForCausalLM, BertGenerationDecoder, BertGenerationEncoder, BertLMHeadModel, @@ -828,3 +830,57 @@ class ProphetNetEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase): def test_encoder_decoder_model_shared_weights(self): pass + + +@require_torch +class BartEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase): + def get_encoder_decoder_model(self, config, decoder_config): + encoder_model = BertModel(config) + decoder_model = BartForCausalLM(decoder_config) + return encoder_model, decoder_model + + def prepare_config_and_inputs(self): + model_tester_encoder = BertModelTester(self, batch_size=13) + model_tester_decoder = BartStandaloneDecoderModelTester( + self, batch_size=13, d_model=32, max_position_embeddings=512 + ) + encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs() + decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = encoder_config_and_inputs + ( + decoder_config, + decoder_input_ids, + decoder_attention_mask, + encoder_hidden_states, + encoder_attention_mask, + lm_labels, + ) = decoder_config_and_inputs + + # make sure that cross attention layers are added + decoder_config.add_cross_attention = True + # disable cache for now + decoder_config.use_cache = False + return { + "config": config, + "input_ids": input_ids, + "attention_mask": input_mask, + "decoder_config": decoder_config, + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": decoder_attention_mask, + "encoder_hidden_states": encoder_hidden_states, + "labels": lm_labels, + } + + def get_pretrained_model(self): + return EncoderDecoderModel.from_encoder_decoder_pretrained("bert-large-uncased", "facebook/bart-large") + + def test_encoder_decoder_model_shared_weights(self): + pass diff --git a/tests/test_modeling_marian.py b/tests/test_modeling_marian.py index 621530a4b9..7da01e043b 100644 --- a/tests/test_modeling_marian.py +++ b/tests/test_modeling_marian.py @@ -14,7 +14,6 @@ # limitations under the License. """ Testing suite for the PyTorch Marian model. """ - import tempfile import unittest @@ -45,7 +44,12 @@ if is_torch_available(): convert_hf_name_to_opus_name, convert_opus_name_to_hf_name, ) - from transformers.models.marian.modeling_marian import MarianDecoder, MarianEncoder, shift_tokens_right + from transformers.models.marian.modeling_marian import ( + MarianDecoder, + MarianEncoder, + MarianForCausalLM, + shift_tokens_right, + ) def prepare_marian_inputs_dict( @@ -182,7 +186,7 @@ class MarianModelTester: self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) # test that outputs are equal for slice - self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2)) + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) def check_encoder_decoder_model_standalone(self, config, inputs_dict): model = MarianModel(config=config).to(torch_device).eval() @@ -546,3 +550,221 @@ class TestConversionUtils(unittest.TestCase): "en-de", ] self.assertListEqual(expected_opus_names, converted_opus_names) + + +class MarianStandaloneDecoderModelTester: + def __init__( + self, + parent, + vocab_size=99, + batch_size=13, + d_model=16, + decoder_seq_length=7, + is_training=True, + is_decoder=True, + use_attention_mask=True, + use_cache=False, + use_labels=True, + decoder_start_token_id=2, + decoder_ffn_dim=32, + decoder_layers=4, + encoder_attention_heads=4, + decoder_attention_heads=4, + max_position_embeddings=30, + is_encoder_decoder=False, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.decoder_seq_length = decoder_seq_length + # For common tests + self.seq_length = self.decoder_seq_length + self.is_training = is_training + self.use_attention_mask = use_attention_mask + self.use_labels = use_labels + + self.vocab_size = vocab_size + self.d_model = d_model + self.hidden_size = d_model + self.num_hidden_layers = decoder_layers + self.decoder_layers = decoder_layers + self.decoder_ffn_dim = decoder_ffn_dim + self.encoder_attention_heads = encoder_attention_heads + self.decoder_attention_heads = decoder_attention_heads + self.num_attention_heads = decoder_attention_heads + self.eos_token_id = eos_token_id + self.bos_token_id = bos_token_id + self.pad_token_id = pad_token_id + self.decoder_start_token_id = decoder_start_token_id + self.use_cache = use_cache + self.max_position_embeddings = max_position_embeddings + self.is_encoder_decoder = is_encoder_decoder + + self.scope = None + self.decoder_key_length = decoder_seq_length + self.base_model_out_len = 2 + self.decoder_attention_idx = 1 + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + attention_mask = None + if self.use_attention_mask: + attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2) + + lm_labels = None + if self.use_labels: + lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + config = MarianConfig( + vocab_size=self.vocab_size, + d_model=self.d_model, + decoder_layers=self.decoder_layers, + decoder_ffn_dim=self.decoder_ffn_dim, + encoder_attention_heads=self.encoder_attention_heads, + decoder_attention_heads=self.decoder_attention_heads, + eos_token_id=self.eos_token_id, + bos_token_id=self.bos_token_id, + use_cache=self.use_cache, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.decoder_start_token_id, + max_position_embeddings=self.max_position_embeddings, + is_encoder_decoder=self.is_encoder_decoder, + ) + + return ( + config, + input_ids, + attention_mask, + lm_labels, + ) + + def create_and_check_decoder_model_past( + self, + config, + input_ids, + attention_mask, + lm_labels, + ): + config.use_cache = True + model = MarianDecoder(config=config).to(torch_device).eval() + # first forward pass + outputs = model(input_ids, use_cache=True) + outputs_use_cache_conf = model(input_ids) + outputs_no_past = model(input_ids, use_cache=False) + + self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) + self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) + + past_key_values = outputs["past_key_values"] + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + + output_from_no_past = model(next_input_ids)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3) + + def create_and_check_decoder_model_attention_mask_past( + self, + config, + input_ids, + attention_mask, + lm_labels, + ): + model = MarianDecoder(config=config).to(torch_device).eval() + + # create attention mask + attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + + half_seq_length = input_ids.shape[-1] // 2 + attn_mask[:, half_seq_length:] = 0 + + # first forward pass + past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"] + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # change a random masked slice from input_ids + random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1 + random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1) + input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens + + # append to next input_ids and attn_mask + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + attn_mask = torch.cat( + [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], + dim=1, + ) + + # get two different outputs + output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] + output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[ + "last_hidden_state" + ] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + attention_mask, + lm_labels, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +@require_torch +class MarianStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = (MarianDecoder, MarianForCausalLM) if is_torch_available() else () + all_generative_model_classes = (MarianForCausalLM,) if is_torch_available() else () + test_pruning = False + is_encoder_decoder = False + + def setUp( + self, + ): + self.model_tester = MarianStandaloneDecoderModelTester(self, is_training=False) + self.config_tester = ConfigTester(self, config_class=MarianConfig) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_decoder_model_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past(*config_and_inputs) + + def test_decoder_model_attn_mask_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs) + + def test_retain_grad_hidden_states_attentions(self): + # decoder cannot keep gradients + return diff --git a/tests/test_modeling_mbart.py b/tests/test_modeling_mbart.py index 0e88ccb037..7355eb146a 100644 --- a/tests/test_modeling_mbart.py +++ b/tests/test_modeling_mbart.py @@ -35,6 +35,7 @@ if is_torch_available(): AutoTokenizer, BatchEncoding, MBartConfig, + MBartForCausalLM, MBartForConditionalGeneration, MBartForQuestionAnswering, MBartForSequenceClassification, @@ -174,7 +175,7 @@ class MBartModelTester: self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) # test that outputs are equal for slice - self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2)) + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) def check_encoder_decoder_model_standalone(self, config, inputs_dict): model = MBartModel(config=config).to(torch_device).eval() @@ -431,3 +432,221 @@ class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest): outputs, clean_up_tokenization_spaces=True, skip_special_tokens=True )[0] self.assertEqual(prediction, "of the best books I ever read!") + + +class MBartStandaloneDecoderModelTester: + def __init__( + self, + parent, + vocab_size=99, + batch_size=13, + d_model=16, + decoder_seq_length=7, + is_training=True, + is_decoder=True, + use_attention_mask=True, + use_cache=False, + use_labels=True, + decoder_start_token_id=2, + decoder_ffn_dim=32, + decoder_layers=4, + encoder_attention_heads=4, + decoder_attention_heads=4, + max_position_embeddings=30, + is_encoder_decoder=False, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.decoder_seq_length = decoder_seq_length + # For common tests + self.seq_length = self.decoder_seq_length + self.is_training = is_training + self.use_attention_mask = use_attention_mask + self.use_labels = use_labels + + self.vocab_size = vocab_size + self.d_model = d_model + self.hidden_size = d_model + self.num_hidden_layers = decoder_layers + self.decoder_layers = decoder_layers + self.decoder_ffn_dim = decoder_ffn_dim + self.encoder_attention_heads = encoder_attention_heads + self.decoder_attention_heads = decoder_attention_heads + self.num_attention_heads = decoder_attention_heads + self.eos_token_id = eos_token_id + self.bos_token_id = bos_token_id + self.pad_token_id = pad_token_id + self.decoder_start_token_id = decoder_start_token_id + self.use_cache = use_cache + self.max_position_embeddings = max_position_embeddings + self.is_encoder_decoder = is_encoder_decoder + + self.scope = None + self.decoder_key_length = decoder_seq_length + self.base_model_out_len = 2 + self.decoder_attention_idx = 1 + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + attention_mask = None + if self.use_attention_mask: + attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2) + + lm_labels = None + if self.use_labels: + lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + config = MBartConfig( + vocab_size=self.vocab_size, + d_model=self.d_model, + decoder_layers=self.decoder_layers, + decoder_ffn_dim=self.decoder_ffn_dim, + encoder_attention_heads=self.encoder_attention_heads, + decoder_attention_heads=self.decoder_attention_heads, + eos_token_id=self.eos_token_id, + bos_token_id=self.bos_token_id, + use_cache=self.use_cache, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.decoder_start_token_id, + max_position_embeddings=self.max_position_embeddings, + is_encoder_decoder=self.is_encoder_decoder, + ) + + return ( + config, + input_ids, + attention_mask, + lm_labels, + ) + + def create_and_check_decoder_model_past( + self, + config, + input_ids, + attention_mask, + lm_labels, + ): + config.use_cache = True + model = MBartDecoder(config=config).to(torch_device).eval() + # first forward pass + outputs = model(input_ids, use_cache=True) + outputs_use_cache_conf = model(input_ids) + outputs_no_past = model(input_ids, use_cache=False) + + self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) + self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) + + past_key_values = outputs["past_key_values"] + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + + output_from_no_past = model(next_input_ids)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3) + + def create_and_check_decoder_model_attention_mask_past( + self, + config, + input_ids, + attention_mask, + lm_labels, + ): + model = MBartDecoder(config=config).to(torch_device).eval() + + # create attention mask + attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + + half_seq_length = input_ids.shape[-1] // 2 + attn_mask[:, half_seq_length:] = 0 + + # first forward pass + past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"] + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # change a random masked slice from input_ids + random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1 + random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1) + input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens + + # append to next input_ids and attn_mask + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + attn_mask = torch.cat( + [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], + dim=1, + ) + + # get two different outputs + output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] + output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[ + "last_hidden_state" + ] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + attention_mask, + lm_labels, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +@require_torch +class MBartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = (MBartDecoder, MBartForCausalLM) if is_torch_available() else () + all_generative_model_classes = (MBartForCausalLM,) if is_torch_available() else () + test_pruning = False + is_encoder_decoder = False + + def setUp( + self, + ): + self.model_tester = MBartStandaloneDecoderModelTester(self, is_training=False) + self.config_tester = ConfigTester(self, config_class=MBartConfig) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_decoder_model_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past(*config_and_inputs) + + def test_decoder_model_attn_mask_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs) + + def test_retain_grad_hidden_states_attentions(self): + # decoder cannot keep gradients + return diff --git a/tests/test_modeling_pegasus.py b/tests/test_modeling_pegasus.py index 51c16726de..c0418d8c89 100644 --- a/tests/test_modeling_pegasus.py +++ b/tests/test_modeling_pegasus.py @@ -14,7 +14,6 @@ # limitations under the License. """ Testing suite for the PyTorch PEGASUS model. """ - import tempfile import unittest @@ -32,7 +31,7 @@ if is_torch_available(): import torch from transformers import AutoModelForSeq2SeqLM, PegasusConfig, PegasusForConditionalGeneration, PegasusModel - from transformers.models.pegasus.modeling_pegasus import PegasusDecoder, PegasusEncoder + from transformers.models.pegasus.modeling_pegasus import PegasusDecoder, PegasusEncoder, PegasusForCausalLM def prepare_pegasus_inputs_dict( @@ -166,7 +165,7 @@ class PegasusModelTester: self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) # test that outputs are equal for slice - self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2)) + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) def check_encoder_decoder_model_standalone(self, config, inputs_dict): model = PegasusModel(config=config).to(torch_device).eval() @@ -308,3 +307,221 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest): "California's largest electricity provider has begun", "N-Dubz have revealed they were", ] + + +class PegasusStandaloneDecoderModelTester: + def __init__( + self, + parent, + vocab_size=99, + batch_size=13, + d_model=16, + decoder_seq_length=7, + is_training=True, + is_decoder=True, + use_attention_mask=True, + use_cache=False, + use_labels=True, + decoder_start_token_id=2, + decoder_ffn_dim=32, + decoder_layers=4, + encoder_attention_heads=4, + decoder_attention_heads=4, + max_position_embeddings=30, + is_encoder_decoder=False, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.decoder_seq_length = decoder_seq_length + # For common tests + self.seq_length = self.decoder_seq_length + self.is_training = is_training + self.use_attention_mask = use_attention_mask + self.use_labels = use_labels + + self.vocab_size = vocab_size + self.d_model = d_model + self.hidden_size = d_model + self.num_hidden_layers = decoder_layers + self.decoder_layers = decoder_layers + self.decoder_ffn_dim = decoder_ffn_dim + self.encoder_attention_heads = encoder_attention_heads + self.decoder_attention_heads = decoder_attention_heads + self.num_attention_heads = decoder_attention_heads + self.eos_token_id = eos_token_id + self.bos_token_id = bos_token_id + self.pad_token_id = pad_token_id + self.decoder_start_token_id = decoder_start_token_id + self.use_cache = use_cache + self.max_position_embeddings = max_position_embeddings + self.is_encoder_decoder = is_encoder_decoder + + self.scope = None + self.decoder_key_length = decoder_seq_length + self.base_model_out_len = 2 + self.decoder_attention_idx = 1 + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + attention_mask = None + if self.use_attention_mask: + attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2) + + lm_labels = None + if self.use_labels: + lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + config = PegasusConfig( + vocab_size=self.vocab_size, + d_model=self.d_model, + decoder_layers=self.decoder_layers, + decoder_ffn_dim=self.decoder_ffn_dim, + encoder_attention_heads=self.encoder_attention_heads, + decoder_attention_heads=self.decoder_attention_heads, + eos_token_id=self.eos_token_id, + bos_token_id=self.bos_token_id, + use_cache=self.use_cache, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.decoder_start_token_id, + max_position_embeddings=self.max_position_embeddings, + is_encoder_decoder=self.is_encoder_decoder, + ) + + return ( + config, + input_ids, + attention_mask, + lm_labels, + ) + + def create_and_check_decoder_model_past( + self, + config, + input_ids, + attention_mask, + lm_labels, + ): + config.use_cache = True + model = PegasusDecoder(config=config).to(torch_device).eval() + # first forward pass + outputs = model(input_ids, use_cache=True) + outputs_use_cache_conf = model(input_ids) + outputs_no_past = model(input_ids, use_cache=False) + + self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) + self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) + + past_key_values = outputs["past_key_values"] + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + + output_from_no_past = model(next_input_ids)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3) + + def create_and_check_decoder_model_attention_mask_past( + self, + config, + input_ids, + attention_mask, + lm_labels, + ): + model = PegasusDecoder(config=config).to(torch_device).eval() + + # create attention mask + attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + + half_seq_length = input_ids.shape[-1] // 2 + attn_mask[:, half_seq_length:] = 0 + + # first forward pass + past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"] + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # change a random masked slice from input_ids + random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1 + random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1) + input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens + + # append to next input_ids and attn_mask + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + attn_mask = torch.cat( + [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], + dim=1, + ) + + # get two different outputs + output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] + output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[ + "last_hidden_state" + ] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + attention_mask, + lm_labels, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +@require_torch +class PegasusStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = (PegasusDecoder, PegasusForCausalLM) if is_torch_available() else () + all_generative_model_classes = (PegasusForCausalLM,) if is_torch_available() else () + test_pruning = False + is_encoder_decoder = False + + def setUp( + self, + ): + self.model_tester = PegasusStandaloneDecoderModelTester(self, is_training=False) + self.config_tester = ConfigTester(self, config_class=PegasusConfig) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_decoder_model_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past(*config_and_inputs) + + def test_decoder_model_attn_mask_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs) + + def test_retain_grad_hidden_states_attentions(self): + # decoder cannot keep gradients + return diff --git a/utils/check_repo.py b/utils/check_repo.py index 54f23a5d08..d0ca63ef1f 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -32,17 +32,17 @@ IGNORE_NON_TESTED = [ # models to ignore for not tested "LEDEncoder", # Building part of bigger (tested) model. "LEDDecoder", # Building part of bigger (tested) model. - "BartDecoder", # Building part of bigger (tested) model. + "BartDecoderWrapper", # Building part of bigger (tested) model. "BartEncoder", # Building part of bigger (tested) model. "BertLMHeadModel", # Needs to be setup as decoder. "BlenderbotSmallEncoder", # Building part of bigger (tested) model. - "BlenderbotSmallDecoder", # Building part of bigger (tested) model. + "BlenderbotSmallDecoderWrapper", # Building part of bigger (tested) model. "BlenderbotEncoder", # Building part of bigger (tested) model. - "BlenderbotDecoder", # Building part of bigger (tested) model. + "BlenderbotDecoderWrapper", # Building part of bigger (tested) model. "MBartEncoder", # Building part of bigger (tested) model. - "MBartDecoder", # Building part of bigger (tested) model. + "MBartDecoderWrapper", # Building part of bigger (tested) model. "PegasusEncoder", # Building part of bigger (tested) model. - "PegasusDecoder", # Building part of bigger (tested) model. + "PegasusDecoderWrapper", # Building part of bigger (tested) model. "DPREncoder", # Building part of bigger (tested) model. "DPRSpanPredictor", # Building part of bigger (tested) model. "ProphetNetDecoderWrapper", # Building part of bigger (tested) model. @@ -78,11 +78,14 @@ IGNORE_NON_AUTO_CONFIGURED = [ "LEDEncoder", "LEDDecoder", "BartDecoder", + "BartDecoderWrapper", "BartEncoder", "BlenderbotSmallEncoder", "BlenderbotSmallDecoder", + "BlenderbotSmallDecoderWrapper", "BlenderbotEncoder", "BlenderbotDecoder", + "BlenderbotDecoderWrapper", "DPRContextEncoder", "DPREncoder", "DPRReader", @@ -93,9 +96,11 @@ IGNORE_NON_AUTO_CONFIGURED = [ "MT5EncoderModel", "MBartEncoder", "MBartDecoder", + "MBartDecoderWrapper", "OpenAIGPTDoubleHeadsModel", "PegasusEncoder", "PegasusDecoder", + "PegasusDecoderWrapper", "ProphetNetDecoder", "ProphetNetEncoder", "ProphetNetDecoderWrapper", @@ -205,9 +210,8 @@ def find_tested_models(test_file): with open(os.path.join(PATH_TO_TESTS, test_file), "r", encoding="utf-8", newline="\n") as f: content = f.read() all_models = re.findall(r"all_model_classes\s+=\s+\(\s*\(([^\)]*)\)", content) - # Check with one less parenthesis - if len(all_models) == 0: - all_models = re.findall(r"all_model_classes\s+=\s+\(([^\)]*)\)", content) + # Check with one less parenthesis as well + all_models += re.findall(r"all_model_classes\s+=\s+\(([^\)]*)\)", content) if len(all_models) > 0: model_tested = [] for entry in all_models: