From afc4ece462ad83a090af620ff4da099a0272e171 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 1 Sep 2020 12:38:25 +0200 Subject: [PATCH] [Generate] Facilitate PyTorch generate using `ModelOutputs` (#6735) * fix generate for GPT2 Double Head * fix gpt2 double head model * fix bart / t5 * also add for no beam search * fix no beam search * fix encoder decoder * simplify t5 * simplify t5 * fix t5 tests * fix BART * fix transfo-xl * fix conflict * integrating sylvains and sams comments * fix tf past_decoder_key_values * fix enc dec test --- docs/source/model_doc/encoderdecoder.rst | 9 +- src/transformers/generation_utils.py | 70 +++++---- src/transformers/modeling_bart.py | 80 +++++----- src/transformers/modeling_encoder_decoder.py | 151 ++++++++++++------- src/transformers/modeling_gpt2.py | 23 ++- src/transformers/modeling_openai.py | 12 +- src/transformers/modeling_outputs.py | 26 ++-- src/transformers/modeling_t5.py | 68 ++++----- src/transformers/modeling_tf_gpt2.py | 6 +- src/transformers/modeling_tf_openai.py | 6 +- src/transformers/modeling_tf_outputs.py | 26 ++-- src/transformers/modeling_tf_t5.py | 112 +++++++++----- src/transformers/modeling_transfo_xl.py | 9 ++ tests/test_modeling_encoder_decoder.py | 26 +++- tests/test_modeling_gpt2.py | 6 +- tests/test_modeling_openai.py | 4 +- tests/test_modeling_t5.py | 12 +- tests/test_modeling_tf_gpt2.py | 2 +- tests/test_modeling_tf_openai.py | 2 +- tests/test_modeling_tf_t5.py | 2 +- 20 files changed, 393 insertions(+), 259 deletions(-) diff --git a/docs/source/model_doc/encoderdecoder.rst b/docs/source/model_doc/encoderdecoder.rst index f3105d9131..a63b6044a2 100644 --- a/docs/source/model_doc/encoderdecoder.rst +++ b/docs/source/model_doc/encoderdecoder.rst @@ -1,12 +1,13 @@ Encoder Decoder Models ------------------------ -This class can wrap an encoder model, such as ``BertModel`` and a decoder modeling with a language modeling head, such as ``BertForMaskedLM`` into a encoder-decoder model. +The :class:`~transformers.EncoderDecoderModel` can be used to initialize a sequence-to-sequence model with any pre-trained autoencoding model as the encoder and any pre-trained autoregressive model as the decoder. -The ``EncoderDecoderModel`` class allows to instantiate a encoder decoder model using the ``from_encoder_decoder_pretrain`` class method taking a pretrained encoder and pretrained decoder model as an input. -The ``EncoderDecoderModel`` is saved using the standard ``save_pretrained()`` method and can also again be loaded using the standard ``from_pretrained()`` method. +The effectiveness of initializing sequence-to-sequence models with pre-trained checkpoints for sequence generation tasks was shown in `Leveraging Pre-trained Checkpoints for Sequence Generation Tasks `__ by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. -An application of this architecture could be *summarization* using two pretrained Bert models as is shown in the paper: `Text Summarization with Pretrained Encoders `_ by Yang Liu and Mirella Lapata. +After such an :class:`~transformers.EncoderDecoderModel` has been trained / fine-tuned, it can be saved / loaded just like any other models (see Examples for more information). + +An application of this architecture could be to leverage two pre-trained :obj:`transformers.BertModel` models as the encoder and decoder for a summarization model as was shown in: `Text Summarization with Pretrained Encoders `_ by Yang Liu and Mirella Lapata. ``EncoderDecoderConfig`` diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 302fad2fc4..638bb3b12e 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -20,6 +20,7 @@ import torch from torch import Tensor from torch.nn import functional as F +from .file_utils import ModelOutput from .utils import logging @@ -46,14 +47,6 @@ class GenerationMixin: """ return logits - def _use_cache(self, outputs, use_cache): - """During generation, decide whether to pass the `past` variable to the next forward pass.""" - if len(outputs) <= 1 or use_cache is False: - return False - if hasattr(self.config, "mem_len") and self.config.mem_len == 0: - return False - return True - def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty): """ Enforce the repetition penalty (from the `CTRL paper `__). @@ -137,7 +130,7 @@ class GenerationMixin: attention_mask: Optional[torch.LongTensor] = None, decoder_start_token_id: Optional[int] = None, use_cache: Optional[bool] = None, - **model_specific_kwargs + **model_kwargs ) -> torch.LongTensor: r""" Generates sequences for models with a language modeling head. The method currently supports greedy decoding, @@ -208,7 +201,7 @@ class GenerationMixin: use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding. - model_specific_kwargs: + model_kwargs: Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. Return: @@ -400,7 +393,7 @@ class GenerationMixin: # get encoder and store encoder outputs encoder = self.get_encoder() - encoder_outputs: tuple = encoder(input_ids, attention_mask=attention_mask) + encoder_outputs: ModelOutput = encoder(input_ids, attention_mask=attention_mask, return_dict=True) # Expand input ids if num_beams > 1 or num_return_sequences > 1 if num_return_sequences > 1 or num_beams > 1: @@ -428,8 +421,8 @@ class GenerationMixin: cur_len = 1 assert ( - batch_size == encoder_outputs[0].shape[0] - ), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} " + batch_size == encoder_outputs.last_hidden_state.shape[0] + ), f"expected encoder_outputs.last_hidden_state to have 1st dimension bs={batch_size}, got {encoder_outputs.last_hidden_state.shape[0]} " # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1) expanded_batch_idxs = ( @@ -439,11 +432,16 @@ class GenerationMixin: .view(-1) .to(input_ids.device) ) + # expand encoder_outputs - encoder_outputs = (encoder_outputs[0].index_select(0, expanded_batch_idxs), *encoder_outputs[1:]) + encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select( + 0, expanded_batch_idxs + ) + + # save encoder_outputs in `model_kwargs` + model_kwargs["encoder_outputs"] = encoder_outputs else: - encoder_outputs = None cur_len = input_ids.shape[-1] assert ( @@ -471,10 +469,9 @@ class GenerationMixin: length_penalty=length_penalty, num_beams=num_beams, vocab_size=vocab_size, - encoder_outputs=encoder_outputs, attention_mask=attention_mask, use_cache=use_cache, - model_specific_kwargs=model_specific_kwargs, + model_kwargs=model_kwargs, ) else: output = self._generate_no_beam_search( @@ -492,10 +489,9 @@ class GenerationMixin: pad_token_id=pad_token_id, eos_token_id=eos_token_id, batch_size=effective_batch_size, - encoder_outputs=encoder_outputs, attention_mask=attention_mask, use_cache=use_cache, - model_specific_kwargs=model_specific_kwargs, + model_kwargs=model_kwargs, ) return output @@ -516,10 +512,9 @@ class GenerationMixin: pad_token_id, eos_token_id, batch_size, - encoder_outputs, attention_mask, use_cache, - model_specific_kwargs, + model_kwargs, ): """Generate sequences for each example without beam search (num_beams == 1). All returned sequence are generated independantly. @@ -528,15 +523,14 @@ class GenerationMixin: unfinished_sents = input_ids.new(batch_size).fill_(1) sent_lengths = input_ids.new(batch_size).fill_(max_length) - past = (encoder_outputs, None) if encoder_outputs is not None else None - + past = None while cur_len < max_length: model_inputs = self.prepare_inputs_for_generation( - input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs + input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs ) - outputs = self(**model_inputs) - next_token_logits = outputs[0][:, -1, :] + outputs = self(**model_inputs, return_dict=True) + next_token_logits = outputs.logits[:, -1, :] scores = self.postprocess_next_token_scores( scores=next_token_logits, @@ -553,8 +547,10 @@ class GenerationMixin: ) # if model has past, then set the past variable to speed up decoding - if self._use_cache(outputs, use_cache): - past = outputs[1] + if "past_key_values" in outputs: + past = outputs.past_key_values + elif "mems" in outputs: + past = outputs.mems if do_sample: # Temperature (higher temperature => more likely to sample low probability tokens) @@ -621,10 +617,9 @@ class GenerationMixin: length_penalty, num_beams, vocab_size, - encoder_outputs, attention_mask, use_cache, - model_specific_kwargs, + model_kwargs, ): """Generate sequences for each example with beam search.""" @@ -643,21 +638,24 @@ class GenerationMixin: beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,) # cache compute states - past = (encoder_outputs, None) if encoder_outputs is not None else None + past = None # done sentences done = [False for _ in range(batch_size)] while cur_len < max_length: model_inputs = self.prepare_inputs_for_generation( - input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs + input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs ) - outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size) - next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size) + outputs = self(**model_inputs, return_dict=True) # (batch_size * num_beams, cur_len, vocab_size) + next_token_logits = outputs.logits[:, -1, :] # (batch_size * num_beams, vocab_size) # if model has past, then set the past variable to speed up decoding - if self._use_cache(outputs, use_cache): - past = outputs[1] + if "past_key_values" in outputs: + past = outputs.past_key_values + elif "mems" in outputs: + past = outputs.mems + if self.config.is_encoder_decoder and do_sample is False: # TODO (PVP) still a bit hacky here - there might be a better solution next_token_logits = self.adjust_logits_during_generation( diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index 45b40554cd..4122d3aa9d 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -111,15 +111,15 @@ BART_INPUTS_DOCSTRING = r""" Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default. If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify. See diagram 1 in the paper for more info on the default strategy - decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains pre-computed key and value hidden-states of the attention blocks. Can be used to speed up decoding. - If ``decoder_past_key_value_states`` are used, the user can optionally input only the last + If ``past_key_values`` are used, the user can optionally input only the last ``decoder_input_ids`` (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`. use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): - If `use_cache` is True, ``decoder_past_key_values`` are returned and can be used to speed up decoding (see - ``decoder_past_key_values``). + If `use_cache` is True, ``past_key_values`` are returned and can be used to speed up decoding (see + ``past_key_values``). output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`): If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail. output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`): @@ -502,7 +502,7 @@ class BartDecoder(nn.Module): encoder_padding_mask, decoder_padding_mask, decoder_causal_mask, - decoder_past_key_values=None, + past_key_values=None, use_cache=False, output_attentions=False, output_hidden_states=False, @@ -519,7 +519,7 @@ class BartDecoder(nn.Module): encoder_hidden_states: output from the encoder, used for encoder-side attention encoder_padding_mask: for ignoring pad tokens - decoder_past_key_values (dict or None): dictionary used for storing state during generation + past_key_values (dict or None): dictionary used for storing state during generation Returns: BaseModelOutputWithPast or tuple: @@ -530,10 +530,16 @@ class BartDecoder(nn.Module): """ if "decoder_cached_states" in unused: warnings.warn( - "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.", + "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", FutureWarning, ) - decoder_past_key_values = unused.pop("decoder_cached_states") + past_key_values = unused.pop("decoder_cached_states") + if "decoder_past_key_values" in unused: + warnings.warn( + "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", + FutureWarning, + ) + past_key_values = unused.pop("decoder_past_key_values") # check attention mask and invert if encoder_padding_mask is not None: @@ -568,7 +574,7 @@ class BartDecoder(nn.Module): if self.training and (dropout_probability < self.layerdrop): continue - layer_state = decoder_past_key_values[idx] if decoder_past_key_values is not None else None + layer_state = past_key_values[idx] if past_key_values is not None else None x, layer_self_attn, layer_past = decoder_layer( x, @@ -594,10 +600,7 @@ class BartDecoder(nn.Module): x = x.transpose(0, 1) encoder_hidden_states = encoder_hidden_states.transpose(0, 1) - if use_cache: - next_cache = ((encoder_hidden_states, encoder_padding_mask), next_decoder_cache) - else: - next_cache = None + next_cache = next_decoder_cache if use_cache else None if not return_dict: return tuple(v for v in [x, next_cache, all_hidden_states, all_self_attns] if v is not None) @@ -869,13 +872,19 @@ class BartModel(PretrainedBartModel): decoder_input_ids=None, encoder_outputs: Optional[Tuple] = None, decoder_attention_mask=None, - decoder_past_key_values=None, + past_key_values=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, **kwargs, ): + if "decoder_past_key_values" in kwargs: + warnings.warn( + "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", + FutureWarning, + ) + past_key_values = kwargs.pop("decoder_past_key_values") if decoder_input_ids is None: use_cache = False @@ -924,7 +933,7 @@ class BartModel(PretrainedBartModel): attention_mask, decoder_padding_mask, decoder_causal_mask=causal_mask, - decoder_past_key_values=decoder_past_key_values, + past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -936,7 +945,7 @@ class BartModel(PretrainedBartModel): return Seq2SeqModelOutput( last_hidden_state=decoder_outputs.last_hidden_state, - decoder_past_key_values=decoder_outputs.past_key_values, + past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, encoder_last_hidden_state=encoder_outputs.last_hidden_state, @@ -994,7 +1003,7 @@ class BartForConditionalGeneration(PretrainedBartModel): encoder_outputs=None, decoder_input_ids=None, decoder_attention_mask=None, - decoder_past_key_values=None, + past_key_values=None, labels=None, use_cache=None, output_attentions=None, @@ -1037,10 +1046,16 @@ class BartForConditionalGeneration(PretrainedBartModel): labels = unused.pop("lm_labels") if "decoder_cached_states" in unused: warnings.warn( - "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.", + "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", FutureWarning, ) - decoder_past_key_values = unused.pop("decoder_cached_states") + past_key_values = unused.pop("decoder_cached_states") + if "decoder_past_key_values" in unused: + warnings.warn( + "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", + FutureWarning, + ) + past_key_values = unused.pop("decoder_past_key_values") return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: @@ -1054,7 +1069,7 @@ class BartForConditionalGeneration(PretrainedBartModel): decoder_input_ids=decoder_input_ids, encoder_outputs=encoder_outputs, decoder_attention_mask=decoder_attention_mask, - decoder_past_key_values=decoder_past_key_values, + past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -1075,7 +1090,7 @@ class BartForConditionalGeneration(PretrainedBartModel): return Seq2SeqLMOutput( loss=masked_lm_loss, logits=lm_logits, - decoder_past_key_values=outputs.decoder_past_key_values, + past_key_values=outputs.past_key_values, decoder_hidden_states=outputs.decoder_hidden_states, decoder_attentions=outputs.decoder_attentions, encoder_last_hidden_state=outputs.encoder_last_hidden_state, @@ -1083,14 +1098,13 @@ class BartForConditionalGeneration(PretrainedBartModel): encoder_attentions=outputs.encoder_attentions, ) - def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs): - assert past is not None, "past has to be defined for encoder_outputs" - - encoder_outputs, decoder_past_key_values = past + def prepare_inputs_for_generation( + self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs + ): return { "input_ids": None, # encoder_outputs is defined. input_ids not needed "encoder_outputs": encoder_outputs, - "decoder_past_key_values": decoder_past_key_values, + "past_key_values": past, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "use_cache": use_cache, # change this to avoid caching (presumably for debugging) @@ -1109,20 +1123,14 @@ class BartForConditionalGeneration(PretrainedBartModel): @staticmethod def _reorder_cache(past, beam_idx): - ((enc_out, enc_mask), decoder_past_key_values) = past reordered_past = [] - for layer_past in decoder_past_key_values: + for layer_past in past: # get the correct batch idx from decoder layer's batch dim for cross and self-attn layer_past_new = { attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items() } reordered_past.append(layer_past_new) - - new_enc_out = enc_out if enc_out is None else enc_out.index_select(0, beam_idx) - new_enc_mask = enc_mask if enc_mask is None else enc_mask.index_select(0, beam_idx) - - past = ((new_enc_out, new_enc_mask), reordered_past) - return past + return reordered_past def get_encoder(self): return self.model.encoder @@ -1208,7 +1216,7 @@ class BartForSequenceClassification(PretrainedBartModel): return Seq2SeqSequenceClassifierOutput( loss=loss, logits=logits, - decoder_past_key_values=outputs.decoder_past_key_values, + past_key_values=outputs.past_key_values, decoder_hidden_states=outputs.decoder_hidden_states, decoder_attentions=outputs.decoder_attentions, encoder_last_hidden_state=outputs.encoder_last_hidden_state, @@ -1316,7 +1324,7 @@ class BartForQuestionAnswering(PretrainedBartModel): loss=total_loss, start_logits=start_logits, end_logits=end_logits, - decoder_past_key_values=outputs.decoder_past_key_values, + past_key_values=outputs.past_key_values, decoder_hidden_states=outputs.decoder_hidden_states, decoder_attentions=outputs.decoder_attentions, encoder_last_hidden_state=outputs.encoder_last_hidden_state, diff --git a/src/transformers/modeling_encoder_decoder.py b/src/transformers/modeling_encoder_decoder.py index b737fa7791..343c65321a 100644 --- a/src/transformers/modeling_encoder_decoder.py +++ b/src/transformers/modeling_encoder_decoder.py @@ -19,13 +19,79 @@ from typing import Optional from .configuration_encoder_decoder import EncoderDecoderConfig from .configuration_utils import PretrainedConfig +from .file_utils import add_start_docstrings, add_start_docstrings_to_callable, replace_return_docstrings +from .modeling_outputs import Seq2SeqLMOutput from .modeling_utils import PreTrainedModel from .utils import logging logger = logging.get_logger(__name__) +_CONFIG_FOR_DOC = "EncoderDecoderConfig" +ENCODER_DECODER_START_DOCSTRING = r""" + This class can be used to inialize a sequence-to-sequnece model with any pretrained autoencoding model as the encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via :meth:`~transformers.AutoModel.from_pretrained` function and the decoder is loaded via :meth:`~transformers.AutoModelForCausalLM.from_pretrained` function. + Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream generative task, *i.e.* summarization. + + The effectiveness of initializing sequence-to-sequence models with pre-trained checkpoints for sequence generation tasks was shown in `Leveraging Pre-trained Checkpoints for Sequence Generation Tasks `__ by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. + Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. + + After such an Encoder Decoder model has been trained / fine-tuned, it can be saved / loaded just like any other models (see Examples for more information). + + This model is a PyTorch `torch.nn.Module `__ sub-class. Use it as a + regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. + + Parameters: + config (:class:`~transformers.T5Config`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the configuration. + Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. +""" + +ENCODER_DECODER_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary for the encoder. + Indices can be obtained using :class:`~transformers.PretrainedTokenizer`. + See :meth:`~transformers.PreTrainedTokenizer.encode` and + :meth:`~transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Mask to avoid performing attention on padding token indices for the encoder. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. + encoder_outputs (:obj:`tuple(torch.FloatTensor)`, `optional`, defaults to :obj:`None`): + This tuple must consist of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`) + `last_hidden_state` (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`) is a tensor of hidden-states at the output of the last layer of the encoder. + Used in the cross-attention of the decoder. + decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`): + Provide for sequence to sequence training to the decoder. + Indices can be obtained using :class:`transformers.PretrainedTokenizer`. + See :func:`transformers.PreTrainedTokenizer.encode` and + :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. + decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`): + Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default. + decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): + Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Labels for computing the masked language modeling loss for the decoder. + Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) + Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels + in ``[0, ..., config.vocab_size]`` + return_dict (:obj:`bool`, `optional`, defaults to :obj:`None`): + If set to ``True``, the model will return a :class:`~transformers.file_utils.Seq2SeqLMOutput` instead of a + plain tuple. + kwargs: (`optional`) Remaining dictionary of keyword arguments. Keyword arguments come in two flavors: + - Without a prefix which will be input as ``**encoder_kwargs`` for the encoder forward function. + - With a `decoder_` prefix which will be input as ``**decoder_kwargs`` for the decoder forward function. +""" + + +@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING) class EncoderDecoderModel(PreTrainedModel): r""" :class:`~transformers.EncoderDecoder` is a generic model class that will be @@ -206,6 +272,8 @@ class EncoderDecoderModel(PreTrainedModel): config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs) return cls(encoder=encoder, decoder=decoder, config=config) + @add_start_docstrings_to_callable(ENCODER_DECODER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids=None, @@ -216,47 +284,11 @@ class EncoderDecoderModel(PreTrainedModel): decoder_attention_mask=None, decoder_inputs_embeds=None, labels=None, + return_dict=None, **kwargs, ): - - """ - Args: - input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary for the encoder. - Indices can be obtained using :class:`transformers.PretrainedTokenizer`. - See :func:`transformers.PreTrainedTokenizer.encode` and - :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. - inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): - Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): - Mask to avoid performing attention on padding token indices for the encoder. - Mask values selected in ``[0, 1]``: - ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. - encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`, defaults to :obj:`None`): - Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`) - `last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`) is a sequence of hidden-states at the output of the last layer of the encoder. - Used in the cross-attention of the decoder. - decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`): - Provide for sequence to sequence training to the decoder. - Indices can be obtained using :class:`transformers.PretrainedTokenizer`. - See :func:`transformers.PreTrainedTokenizer.encode` and - :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. - decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`): - Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default. - decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): - Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): - Labels for computing the masked language modeling loss for the decoder. - Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) - Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels - in ``[0, ..., config.vocab_size]`` - kwargs: (`optional`) Remaining dictionary of keyword arguments. Keyword arguments come in two flavors: - - Without a prefix which will be input as `**encoder_kwargs` for the encoder forward function. - - With a `decoder_` prefix which will be input as `**decoder_kwargs` for the decoder forward function. + r""" + Returns: Examples:: @@ -264,19 +296,25 @@ class EncoderDecoderModel(PreTrainedModel): >>> import torch >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') - >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert + >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert from pre-trained checkpoints >>> # forward >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 >>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids) >>> # training - >>> loss, outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=input_ids)[:2] + >>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=input_ids, return_dict=True) + >>> loss, logits = outputs.loss, outputs.logits + + >>> # save and load from pretrained + >>> model.save_pretrained("bert2bert") + >>> model = EncoderDecoderModel.from_pretrained("bert2bert") >>> # generation >>> generated = model.generate(input_ids, decoder_start_token_id=model.config.decoder.pad_token_id) """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")} @@ -289,7 +327,7 @@ class EncoderDecoderModel(PreTrainedModel): input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, - return_dict=False, + return_dict=return_dict, **kwargs_encoder, ) @@ -303,23 +341,28 @@ class EncoderDecoderModel(PreTrainedModel): encoder_hidden_states=hidden_states, encoder_attention_mask=attention_mask, labels=labels, - return_dict=False, + return_dict=return_dict, **kwargs_decoder, ) # TODO(PVP): currently it is not possible to use `past` - # with the encoder/decoder framework -> should be implemented + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqLMOutput( + loss=decoder_outputs.loss, + logits=decoder_outputs.logits, + past_key_values=None, # TODO(PVP) - need to implement cache for BERT, etc... before this works + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + return decoder_outputs + encoder_outputs - def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwargs): - assert past is not None, "past has to be defined for encoder_outputs" - - # first step - if type(past) is tuple: - encoder_outputs, _ = past - else: - encoder_outputs = (past,) - + def prepare_inputs_for_generation(self, input_ids, past, attention_mask, encoder_outputs, **kwargs): decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids) decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None input_dict = { @@ -335,7 +378,7 @@ class EncoderDecoderModel(PreTrainedModel): input_dict["decoder_use_cache"] = decoder_inputs["use_cache"] if "past_key_values" in decoder_inputs: - input_dict["decoder_past_key_values"] = decoder_inputs["past_key_values"] + input_dict["past_key_values"] = decoder_inputs["past_key_values"] return input_dict diff --git a/src/transformers/modeling_gpt2.py b/src/transformers/modeling_gpt2.py index 1d4ceb0e2f..727a3a87c3 100644 --- a/src/transformers/modeling_gpt2.py +++ b/src/transformers/modeling_gpt2.py @@ -353,11 +353,11 @@ class GPT2DoubleHeadsModelOutput(ModelOutput): Base class for outputs of models predicting if two sentences are consecutive or not. Args: - lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided): + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided): Language modeling loss. mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided): Multiple choice classification loss. - lm_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`): + logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`): Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). @@ -380,9 +380,9 @@ class GPT2DoubleHeadsModelOutput(ModelOutput): heads. """ - lm_loss: Optional[torch.FloatTensor] = None + loss: Optional[torch.FloatTensor] = None mc_loss: Optional[torch.FloatTensor] = None - lm_logits: torch.FloatTensor = None + logits: torch.FloatTensor = None mc_logits: torch.FloatTensor = None past_key_values: Optional[List[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None @@ -777,6 +777,17 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): def get_output_embeddings(self): return self.lm_head + def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): + # only last token for inputs_ids if past is defined in kwargs + if past: + input_ids = input_ids[:, -1].unsqueeze(-1) + + return { + "input_ids": input_ids, + "past_key_values": past, + "use_cache": kwargs.get("use_cache"), + } + @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) def forward( @@ -893,9 +904,9 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): return ((lm_loss,) + output) if lm_loss is not None else output return GPT2DoubleHeadsModelOutput( - lm_loss=lm_loss, + loss=lm_loss, mc_loss=mc_loss, - lm_logits=lm_logits, + logits=lm_logits, mc_logits=mc_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, diff --git a/src/transformers/modeling_openai.py b/src/transformers/modeling_openai.py index 1920880b28..e62d13455d 100644 --- a/src/transformers/modeling_openai.py +++ b/src/transformers/modeling_openai.py @@ -300,11 +300,11 @@ class OpenAIGPTDoubleHeadsModelOutput(ModelOutput): Base class for outputs of models predicting if two sentences are consecutive or not. Args: - lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided): + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided): Language modeling loss. mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided): Multiple choice classification loss. - lm_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`): + logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`): Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). @@ -321,9 +321,9 @@ class OpenAIGPTDoubleHeadsModelOutput(ModelOutput): heads. """ - lm_loss: Optional[torch.FloatTensor] = None + loss: Optional[torch.FloatTensor] = None mc_loss: Optional[torch.FloatTensor] = None - lm_logits: torch.FloatTensor = None + logits: torch.FloatTensor = None mc_logits: torch.FloatTensor = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None @@ -713,9 +713,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): return ((lm_loss,) + output) if lm_loss is not None else output return OpenAIGPTDoubleHeadsModelOutput( - lm_loss=lm_loss, + loss=lm_loss, mc_loss=mc_loss, - lm_logits=lm_logits, + logits=lm_logits, mc_logits=mc_logits, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, diff --git a/src/transformers/modeling_outputs.py b/src/transformers/modeling_outputs.py index 1c36dc2d81..e6a4b0ed8b 100644 --- a/src/transformers/modeling_outputs.py +++ b/src/transformers/modeling_outputs.py @@ -109,13 +109,13 @@ class Seq2SeqModelOutput(ModelOutput): last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the decoder of the model. - If ``decoder_past_key_values`` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output. - decoder_past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + If ``past_key_values`` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output. + past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`). Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be - used (see ``decoder_past_key_values`` input) to speed up sequential decoding. + used (see ``past_key_values`` input) to speed up sequential decoding. decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. @@ -143,7 +143,7 @@ class Seq2SeqModelOutput(ModelOutput): """ last_hidden_state: torch.FloatTensor - decoder_past_key_values: Optional[List[torch.FloatTensor]] = None + past_key_values: Optional[List[torch.FloatTensor]] = None decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_last_hidden_state: Optional[torch.FloatTensor] = None @@ -255,12 +255,12 @@ class Seq2SeqLMOutput(ModelOutput): Languaged modeling loss. logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - decoder_past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`). Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be - used (see ``decoder_past_key_values`` input) to speed up sequential decoding. + used (see ``past_key_values`` input) to speed up sequential decoding. decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. @@ -289,7 +289,7 @@ class Seq2SeqLMOutput(ModelOutput): loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None - decoder_past_key_values: Optional[List[torch.FloatTensor]] = None + past_key_values: Optional[List[torch.FloatTensor]] = None decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_last_hidden_state: Optional[torch.FloatTensor] = None @@ -365,12 +365,12 @@ class Seq2SeqSequenceClassifierOutput(ModelOutput): Classification (or regression if config.num_labels==1) loss. logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`): Classification (or regression if config.num_labels==1) scores (before SoftMax). - decoder_past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`). Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be - used (see ``decoder_past_key_values`` input) to speed up sequential decoding. + used (see ``past_key_values`` input) to speed up sequential decoding. decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. @@ -399,7 +399,7 @@ class Seq2SeqSequenceClassifierOutput(ModelOutput): loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None - decoder_past_key_values: Optional[List[torch.FloatTensor]] = None + past_key_values: Optional[List[torch.FloatTensor]] = None decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_last_hidden_state: Optional[torch.FloatTensor] = None @@ -511,12 +511,12 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput): Span-start scores (before SoftMax). end_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`): Span-end scores (before SoftMax). - decoder_past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`). Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be - used (see ``decoder_past_key_values`` input) to speed up sequential decoding. + used (see ``past_key_values`` input) to speed up sequential decoding. decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. @@ -546,7 +546,7 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput): loss: Optional[torch.FloatTensor] = None start_logits: torch.FloatTensor = None end_logits: torch.FloatTensor = None - decoder_past_key_values: Optional[List[torch.FloatTensor]] = None + past_key_values: Optional[List[torch.FloatTensor]] = None decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_last_hidden_state: Optional[torch.FloatTensor] = None diff --git a/src/transformers/modeling_t5.py b/src/transformers/modeling_t5.py index a8ae72d0b2..463d9f471e 100644 --- a/src/transformers/modeling_t5.py +++ b/src/transformers/modeling_t5.py @@ -838,27 +838,27 @@ T5_INPUTS_DOCSTRING = r""" Used in the cross-attention of the decoder. decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`): Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation. - If `decoder_past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `decoder_past_key_values`). + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). To know more on how to prepare :obj:`decoder_input_ids` for pre-training take a look at `T5 Training <./t5.html#training>`__. If decoder_input_ids and decoder_inputs_embeds are both None, decoder_input_ids takes the value of input_ids. decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`): Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default. - decoder_past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains pre-computed key and value hidden-states of the attention blocks. Can be used to speed up decoding. - If `decoder_past_key_values` are used, the user can optionally input only the last `decoder_input_ids` + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` instead of all `decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): - If `use_cache` is True, `decoder_past_key_values` are returned and can be used to speed up decoding (see `decoder_past_key_values`). + If `use_cache` is True, `past_key_values` are returned and can be used to speed up decoding (see `past_key_values`). inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation. - If `decoder_past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be input (see `decoder_past_key_values`). + If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be input (see `past_key_values`). This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. If decoder_input_ids and decoder_inputs_embeds are both None, decoder_inputs_embeds takes the value of inputs_embeds. @@ -928,7 +928,7 @@ class T5Model(T5PreTrainedModel): encoder_outputs=None, decoder_input_ids=None, decoder_attention_mask=None, - decoder_past_key_values=None, + past_key_values=None, use_cache=None, inputs_embeds=None, decoder_inputs_embeds=None, @@ -955,10 +955,16 @@ class T5Model(T5PreTrainedModel): """ if "decoder_past_key_value_states" in kwargs: warnings.warn( - "The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.", + "The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", FutureWarning, ) - decoder_past_key_values = kwargs.pop("decoder_past_key_value_states") + past_key_values = kwargs.pop("decoder_past_key_value_states") + if "decoder_past_key_values" in kwargs: + warnings.warn( + "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", + FutureWarning, + ) + past_key_values = kwargs.pop("decoder_past_key_values") assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}." use_cache = use_cache if use_cache is not None else self.config.use_cache @@ -992,7 +998,7 @@ class T5Model(T5PreTrainedModel): # If decoding with past key value states, only the last tokens # should be given as an input - if decoder_past_key_values is not None: + if past_key_values is not None: if decoder_input_ids is not None: decoder_input_ids = decoder_input_ids[:, -1:] if decoder_inputs_embeds is not None: @@ -1003,7 +1009,7 @@ class T5Model(T5PreTrainedModel): input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, inputs_embeds=decoder_inputs_embeds, - past_key_value_states=decoder_past_key_values, + past_key_value_states=past_key_values, encoder_hidden_states=hidden_states, encoder_attention_mask=attention_mask, head_mask=head_mask, @@ -1013,15 +1019,12 @@ class T5Model(T5PreTrainedModel): return_dict=return_dict, ) - past = (encoder_outputs, decoder_outputs[1]) if use_cache is True else None if not return_dict: - if past is not None: - decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:] return decoder_outputs + encoder_outputs return Seq2SeqModelOutput( last_hidden_state=decoder_outputs.last_hidden_state, - decoder_past_key_values=past, + past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, encoder_last_hidden_state=encoder_outputs.last_hidden_state, @@ -1080,7 +1083,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel): encoder_outputs=None, decoder_input_ids=None, decoder_attention_mask=None, - decoder_past_key_values=None, + past_key_values=None, use_cache=None, labels=None, inputs_embeds=None, @@ -1127,10 +1130,16 @@ class T5ForConditionalGeneration(T5PreTrainedModel): labels = kwargs.pop("lm_labels") if "decoder_past_key_value_states" in kwargs: warnings.warn( - "The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.", + "The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", FutureWarning, ) - decoder_past_key_values = kwargs.pop("decoder_past_key_value_states") + past_key_values = kwargs.pop("decoder_past_key_value_states") + if "decoder_past_key_values" in kwargs: + warnings.warn( + "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", + FutureWarning, + ) + past_key_values = kwargs.pop("decoder_past_key_values") assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}." use_cache = use_cache if use_cache is not None else self.config.use_cache @@ -1163,7 +1172,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel): # If decoding with past key value states, only the last tokens # should be given as an input - if decoder_past_key_values is not None: + if past_key_values is not None: assert labels is None, "Decoder should not use cached key value states when training." if decoder_input_ids is not None: decoder_input_ids = decoder_input_ids[:, -1:] @@ -1175,7 +1184,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel): input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, inputs_embeds=decoder_inputs_embeds, - past_key_value_states=decoder_past_key_values, + past_key_value_states=past_key_values, encoder_hidden_states=hidden_states, encoder_attention_mask=attention_mask, head_mask=head_mask, @@ -1197,17 +1206,14 @@ class T5ForConditionalGeneration(T5PreTrainedModel): loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 - past = (encoder_outputs, decoder_outputs[1]) if use_cache is True else None if not return_dict: - if past is not None: - decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:] output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs return ((loss,) + output) if loss is not None else output return Seq2SeqLMOutput( loss=loss, logits=lm_logits, - decoder_past_key_values=past, + past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, encoder_last_hidden_state=encoder_outputs.last_hidden_state, @@ -1215,14 +1221,10 @@ class T5ForConditionalGeneration(T5PreTrainedModel): encoder_attentions=encoder_outputs.attentions, ) - def prepare_inputs_for_generation(self, input_ids, past, attention_mask, use_cache, **kwargs): - assert past is not None, "past has to be defined for encoder_outputs" - - encoder_outputs, decoder_past_key_values = past - + def prepare_inputs_for_generation(self, input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs): return { "decoder_input_ids": input_ids, - "decoder_past_key_values": decoder_past_key_values, + "past_key_values": past, "encoder_outputs": encoder_outputs, "attention_mask": attention_mask, "use_cache": use_cache, @@ -1231,14 +1233,12 @@ class T5ForConditionalGeneration(T5PreTrainedModel): def _reorder_cache(self, past, beam_idx): # if decoder past is not included in output # speedy decoding is disabled and no need to reorder - if past[1] is None: + if past is None: logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") return past - decoder_past = past[1] - past = (past[0],) reordered_decoder_past = () - for layer_past_states in decoder_past: + for layer_past_states in past: # get the correct batch idx from layer past batch dim # batch dim of `past` is at 2nd position reordered_layer_past_states = () @@ -1252,4 +1252,4 @@ class T5ForConditionalGeneration(T5PreTrainedModel): assert len(reordered_layer_past_states) == len(layer_past_states) reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) - return past + (reordered_decoder_past,) + return reordered_decoder_past diff --git a/src/transformers/modeling_tf_gpt2.py b/src/transformers/modeling_tf_gpt2.py index e603643c25..439e2906bc 100644 --- a/src/transformers/modeling_tf_gpt2.py +++ b/src/transformers/modeling_tf_gpt2.py @@ -431,7 +431,7 @@ class TFGPT2DoubleHeadsModelOutput(ModelOutput): Base class for outputs of models predicting if two sentences are consecutive or not. Args: - lm_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`): + logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). mc_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`): Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). @@ -454,7 +454,7 @@ class TFGPT2DoubleHeadsModelOutput(ModelOutput): heads. """ - lm_logits: tf.Tensor = None + logits: tf.Tensor = None mc_logits: tf.Tensor = None past_key_values: Optional[List[tf.Tensor]] = None hidden_states: Optional[Tuple[tf.Tensor]] = None @@ -794,7 +794,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel): return (lm_logits, mc_logits) + transformer_outputs[1:] return TFGPT2DoubleHeadsModelOutput( - lm_logits=lm_logits, + logits=lm_logits, mc_logits=mc_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, diff --git a/src/transformers/modeling_tf_openai.py b/src/transformers/modeling_tf_openai.py index 49ca4de86c..0585968457 100644 --- a/src/transformers/modeling_tf_openai.py +++ b/src/transformers/modeling_tf_openai.py @@ -394,7 +394,7 @@ class TFOpenAIGPTDoubleHeadsModelOutput(ModelOutput): Base class for outputs of models predicting if two sentences are consecutive or not. Args: - lm_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`): + logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). mc_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`): Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). @@ -411,7 +411,7 @@ class TFOpenAIGPTDoubleHeadsModelOutput(ModelOutput): heads. """ - lm_logits: tf.Tensor = None + logits: tf.Tensor = None mc_logits: tf.Tensor = None hidden_states: Optional[Tuple[tf.Tensor]] = None attentions: Optional[Tuple[tf.Tensor]] = None @@ -719,7 +719,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel): return (lm_logits, mc_logits) + transformer_outputs[1:] return TFOpenAIGPTDoubleHeadsModelOutput( - lm_logits=lm_logits, + logits=lm_logits, mc_logits=mc_logits, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, diff --git a/src/transformers/modeling_tf_outputs.py b/src/transformers/modeling_tf_outputs.py index 8d61a17572..d0914b6ddf 100644 --- a/src/transformers/modeling_tf_outputs.py +++ b/src/transformers/modeling_tf_outputs.py @@ -113,13 +113,13 @@ class TFSeq2SeqModelOutput(ModelOutput): last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the decoder of the model. - If ``decoder_past_key_values`` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output. - decoder_past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + If ``past_key_values`` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output. + past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`). Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be - used (see ``decoder_past_key_values`` input) to speed up sequential decoding. + used (see ``past_key_values`` input) to speed up sequential decoding. decoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. @@ -147,7 +147,7 @@ class TFSeq2SeqModelOutput(ModelOutput): """ last_hidden_state: tf.Tensor = None - decoder_past_key_values: Optional[List[tf.Tensor]] = None + past_key_values: Optional[List[tf.Tensor]] = None decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None decoder_attentions: Optional[Tuple[tf.Tensor]] = None encoder_last_hidden_state: Optional[tf.Tensor] = None @@ -259,12 +259,12 @@ class TFSeq2SeqLMOutput(ModelOutput): Languaged modeling loss. logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - decoder_past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`). Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be - used (see ``decoder_past_key_values`` input) to speed up sequential decoding. + used (see ``past_key_values`` input) to speed up sequential decoding. decoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. @@ -293,7 +293,7 @@ class TFSeq2SeqLMOutput(ModelOutput): loss: Optional[tf.Tensor] = None logits: tf.Tensor = None - decoder_past_key_values: Optional[List[tf.Tensor]] = None + past_key_values: Optional[List[tf.Tensor]] = None decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None decoder_attentions: Optional[Tuple[tf.Tensor]] = None encoder_last_hidden_state: Optional[tf.Tensor] = None @@ -366,12 +366,12 @@ class TFSeq2SeqSequenceClassifierOutput(ModelOutput): Classification (or regression if config.num_labels==1) loss. logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, config.num_labels)`): Classification (or regression if config.num_labels==1) scores (before SoftMax). - decoder_past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`). Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be - used (see ``decoder_past_key_values`` input) to speed up sequential decoding. + used (see ``past_key_values`` input) to speed up sequential decoding. decoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. @@ -400,7 +400,7 @@ class TFSeq2SeqSequenceClassifierOutput(ModelOutput): loss: Optional[tf.Tensor] = None logits: tf.Tensor = None - decoder_past_key_values: Optional[List[tf.Tensor]] = None + past_key_values: Optional[List[tf.Tensor]] = None decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None decoder_attentions: Optional[Tuple[tf.Tensor]] = None encoder_last_hidden_state: Optional[tf.Tensor] = None @@ -512,12 +512,12 @@ class TFSeq2SeqQuestionAnsweringModelOutput(ModelOutput): Span-start scores (before SoftMax). end_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`): Span-end scores (before SoftMax). - decoder_past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`). Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be - used (see ``decoder_past_key_values`` input) to speed up sequential decoding. + used (see ``past_key_values`` input) to speed up sequential decoding. decoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. @@ -547,7 +547,7 @@ class TFSeq2SeqQuestionAnsweringModelOutput(ModelOutput): loss: Optional[tf.Tensor] = None start_logits: tf.Tensor = None end_logits: tf.Tensor = None - decoder_past_key_values: Optional[List[tf.Tensor]] = None + past_key_values: Optional[List[tf.Tensor]] = None decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None decoder_attentions: Optional[Tuple[tf.Tensor]] = None encoder_last_hidden_state: Optional[tf.Tensor] = None diff --git a/src/transformers/modeling_tf_t5.py b/src/transformers/modeling_tf_t5.py index 6a4379c0f6..9b451c8ff2 100644 --- a/src/transformers/modeling_tf_t5.py +++ b/src/transformers/modeling_tf_t5.py @@ -437,15 +437,15 @@ class TFT5Block(tf.keras.layers.Layer): ): if past_key_value_state is not None: - assert self.is_decoder, "Only decoder can use `past_key_value_states`" - expected_num_past_key_value_states = 2 if encoder_hidden_states is None else 4 + assert self.is_decoder, "Only decoder can use `past_key_values`" + expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 error_message = "There should be {} past states. 2 (past / key) for self attention.{} Got {} past key / value states".format( - expected_num_past_key_value_states, - "2 (past / key) for cross attention" if expected_num_past_key_value_states == 4 else "", + expected_num_past_key_values, + "2 (past / key) for cross attention" if expected_num_past_key_values == 4 else "", len(past_key_value_state), ) - assert len(past_key_value_state) == expected_num_past_key_value_states, error_message + assert len(past_key_value_state) == expected_num_past_key_values, error_message self_attn_past_key_value_state = past_key_value_state[:2] cross_attn_past_key_value_state = past_key_value_state[2:] @@ -586,11 +586,12 @@ class TFT5MainLayer(tf.keras.layers.Layer): encoder_attention_mask=None, inputs_embeds=None, head_mask=None, - past_key_value_states=None, + past_key_values=None, use_cache=None, output_attentions=None, output_hidden_states=None, training=False, + **kwargs, ): if isinstance(inputs, (tuple, list)): input_ids = inputs[0] @@ -599,7 +600,7 @@ class TFT5MainLayer(tf.keras.layers.Layer): encoder_attention_mask = inputs[3] if len(inputs) > 3 else encoder_attention_mask inputs_embeds = inputs[4] if len(inputs) > 4 else inputs_embeds head_mask = inputs[5] if len(inputs) > 5 else head_mask - past_key_value_states = inputs[6] if len(inputs) > 6 else past_key_value_states + past_key_values = inputs[6] if len(inputs) > 6 else past_key_values use_cache = inputs[7] if len(inputs) > 7 else use_cache output_attentions = inputs[8] if len(inputs) > 8 else output_attentions output_hidden_states = inputs[9] if len(inputs) > 9 else output_hidden_states @@ -611,13 +612,26 @@ class TFT5MainLayer(tf.keras.layers.Layer): encoder_attention_mask = inputs.get("encoder_attention_mask", encoder_attention_mask) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) head_mask = inputs.get("head_mask", head_mask) - past_key_value_states = inputs.get("past_key_value_states", past_key_value_states) + past_key_values = inputs.get("past_key_values", past_key_values) use_cache = inputs.get("use_cache", use_cache) output_attentions = inputs.get("output_attentions", output_attentions) output_hidden_states = inputs.get("output_hidden_states", output_hidden_states) assert len(inputs) <= 10, "Too many inputs." + + if "past_key_value_states" in inputs: + warnings.warn( + "The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", + FutureWarning, + ) + past_key_values = inputs.pop("past_key_value_states") else: input_ids = inputs + if "past_key_value_states" in kwargs: + warnings.warn( + "The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", + FutureWarning, + ) + past_key_values = kwargs.pop("past_key_value_states") output_attentions = output_attentions if output_attentions is not None else self.output_attentions output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states @@ -639,13 +653,13 @@ class TFT5MainLayer(tf.keras.layers.Layer): batch_size, seq_length = input_shape - if past_key_value_states is not None: + if past_key_values is not None: assert seq_length == 1, "Input shape is {}, but should be {} when using past_key_value_sates".format( input_shape, (batch_size, 1) ) # required mask seq length can be calculated via length of past # key value states and seq_length = 1 for the last token - mask_seq_length = shape_list(past_key_value_states[0][0])[2] + seq_length + mask_seq_length = shape_list(past_key_values[0][0])[2] + seq_length else: mask_seq_length = seq_length @@ -655,9 +669,9 @@ class TFT5MainLayer(tf.keras.layers.Layer): encoder_seq_length = shape_list(encoder_hidden_states)[1] encoder_attention_mask = tf.fill((batch_size, encoder_seq_length), 1) - # initialize past_key_value_states with `None` if past does not exist - if past_key_value_states is None: - past_key_value_states = [None] * len(self.block) + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block) # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. @@ -677,7 +691,7 @@ class TFT5MainLayer(tf.keras.layers.Layer): ) causal_mask = tf.cast(causal_mask, dtype=tf.float32) extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] - if past_key_value_states[0] is not None: + if past_key_values[0] is not None: extended_attention_mask = extended_attention_mask[:, :, -1:, :] else: extended_attention_mask = attention_mask[:, None, None, :] @@ -726,7 +740,7 @@ class TFT5MainLayer(tf.keras.layers.Layer): hidden_states = self.dropout(inputs_embeds, training=training) - for i, (layer_module, past_key_value_state) in enumerate(zip(self.block, past_key_value_states)): + for i, (layer_module, past_key_value_state) in enumerate(zip(self.block, past_key_values)): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -878,7 +892,7 @@ T5_INPUTS_DOCSTRING = r""" :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. decoder_input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`): Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation. - If `decoder_past_key_value_states` is used, optionally only the last `decoder_input_ids` have to be input (see `decoder_past_key_value_states`). + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: @@ -889,13 +903,13 @@ T5_INPUTS_DOCSTRING = r""" Used in the cross-attention of the decoder. decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`): Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default. - decoder_past_key_value_states (:obj:`tuple(tuple(tf.Tensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + past_key_values (:obj:`tuple(tuple(tf.Tensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains pre-computed key and value hidden-states of the attention blocks. Can be used to speed up decoding. - If `decoder_past_key_value_states` are used, the user can optionally input only the last `decoder_input_ids` + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): - If `use_cache` is True, `decoder_past_key_value_states` are returned and can be used to speed up decoding (see `decoder_past_key_value_states`). + If `use_cache` is True, `past_key_values` are returned and can be used to speed up decoding (see `past_key_values`). inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): Optionally, instead of passing :obj:`inputs` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `inputs` indices into associated vectors @@ -969,7 +983,7 @@ class TFT5Model(TFT5PreTrainedModel): encoder_outputs=None, inputs_embeds=None, head_mask=None, - decoder_past_key_value_states=None, + past_key_values=None, decoder_input_ids=None, decoder_attention_mask=None, decoder_inputs_embeds=None, @@ -978,6 +992,7 @@ class TFT5Model(TFT5PreTrainedModel): output_hidden_states=None, return_dict=None, training=False, + **kwargs, ): r""" Returns: @@ -999,7 +1014,7 @@ class TFT5Model(TFT5PreTrainedModel): encoder_outputs = inputs[2] if len(inputs) > 2 else encoder_outputs inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds head_mask = inputs[4] if len(inputs) > 4 else head_mask - decoder_past_key_value_states = inputs[5] if len(inputs) > 5 else decoder_past_key_value_states + past_key_values = inputs[5] if len(inputs) > 5 else past_key_values decoder_input_ids = inputs[6] if len(inputs) > 6 else decoder_input_ids decoder_attention_mask = inputs[7] if len(inputs) > 7 else decoder_attention_mask decoder_inputs_embeds = inputs[8] if len(inputs) > 8 else decoder_inputs_embeds @@ -1017,7 +1032,7 @@ class TFT5Model(TFT5PreTrainedModel): encoder_outputs = inputs.get("encoder_outputs", encoder_outputs) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) head_mask = inputs.get("head_mask", head_mask) - decoder_past_key_value_states = inputs.get("past_key_value_states", decoder_past_key_value_states) + past_key_values = inputs.get("past_key_values", past_key_values) decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids) decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask) decoder_inputs_embeds = inputs.get("decoder_inputs_embeds", decoder_inputs_embeds) @@ -1026,9 +1041,23 @@ class TFT5Model(TFT5PreTrainedModel): output_hidden_states = inputs.get("output_hidden_states", output_hidden_states) return_dict = inputs.get("return_dict", return_dict) assert len(inputs) <= 13, "Too many inputs." + + if "past_key_value_states" in inputs: + warnings.warn( + "The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", + FutureWarning, + ) + past_key_values = inputs.pop("past_key_value_states") else: input_ids = inputs + if "past_key_value_states" in kwargs: + warnings.warn( + "The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", + FutureWarning, + ) + past_key_values = kwargs.pop("past_key_value_states") + use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.return_dict @@ -1054,7 +1083,7 @@ class TFT5Model(TFT5PreTrainedModel): # If decoding with past key value states, only the last tokens # should be given as an input - if decoder_past_key_value_states is not None: + if past_key_values is not None: if decoder_input_ids is not None: decoder_input_ids = decoder_input_ids[:, -1:] if decoder_inputs_embeds is not None: @@ -1069,7 +1098,7 @@ class TFT5Model(TFT5PreTrainedModel): attention_mask, decoder_inputs_embeds, head_mask, - decoder_past_key_value_states, + past_key_values, use_cache, output_attentions, output_hidden_states, @@ -1103,7 +1132,7 @@ class TFT5Model(TFT5PreTrainedModel): return TFSeq2SeqModelOutput( last_hidden_state=decoder_outputs[0], - decoder_past_key_values=past, + past_key_values=past, decoder_hidden_states=decoder_outputs[2], decoder_attentions=decoder_outputs[3], encoder_last_hidden_state=encoder_outputs[0], @@ -1164,7 +1193,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling encoder_outputs=None, inputs_embeds=None, head_mask=None, - decoder_past_key_value_states=None, + past_key_values=None, decoder_input_ids=None, decoder_attention_mask=None, decoder_inputs_embeds=None, @@ -1174,6 +1203,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling return_dict=None, labels=None, training=False, + **kwargs, ): r""" labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): @@ -1204,7 +1234,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling encoder_outputs = inputs[2] if len(inputs) > 2 else encoder_outputs inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds head_mask = inputs[4] if len(inputs) > 4 else head_mask - decoder_past_key_value_states = inputs[5] if len(inputs) > 5 else decoder_past_key_value_states + past_key_values = inputs[5] if len(inputs) > 5 else past_key_values decoder_input_ids = inputs[6] if len(inputs) > 6 else decoder_input_ids decoder_attention_mask = inputs[7] if len(inputs) > 7 else decoder_attention_mask decoder_inputs_embeds = inputs[8] if len(inputs) > 8 else decoder_inputs_embeds @@ -1223,7 +1253,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling encoder_outputs = inputs.get("encoder_outputs", encoder_outputs) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) head_mask = inputs.get("head_mask", head_mask) - decoder_past_key_value_states = inputs.get("past_key_value_states", decoder_past_key_value_states) + past_key_values = inputs.get("past_key_values", past_key_values) decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids) decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask) decoder_inputs_embeds = inputs.get("decoder_inputs_embeds", decoder_inputs_embeds) @@ -1233,9 +1263,23 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling return_dict = inputs.get("return_dict", return_dict) labels = inputs.get("labels", labels) assert len(inputs) <= 14, "Too many inputs." + + if "past_key_value_states" in inputs: + warnings.warn( + "The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", + FutureWarning, + ) + past_key_values = inputs.pop("past_key_value_states") else: input_ids = inputs + if "past_key_value_states" in kwargs: + warnings.warn( + "The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", + FutureWarning, + ) + past_key_values = kwargs.pop("past_key_value_states") + use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.return_dict @@ -1266,7 +1310,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling # If decoding with past key value states, only the last tokens # should be given as an input - if decoder_past_key_value_states is not None: + if past_key_values is not None: if decoder_input_ids is not None: decoder_input_ids = decoder_input_ids[:, -1:] if decoder_inputs_embeds is not None: @@ -1281,7 +1325,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling attention_mask, decoder_inputs_embeds, head_mask, - decoder_past_key_value_states, + past_key_values, use_cache, output_attentions, output_hidden_states, @@ -1324,7 +1368,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling return TFSeq2SeqLMOutput( loss=loss, logits=logits, - decoder_past_key_values=past, + past_key_values=past, decoder_hidden_states=decoder_outputs[2], decoder_attentions=decoder_outputs[3], encoder_last_hidden_state=encoder_outputs[0], @@ -1337,14 +1381,14 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling # first step if len(past) < 2: - encoder_outputs, decoder_past_key_value_states = past, None + encoder_outputs, past_key_values = past, None else: - encoder_outputs, decoder_past_key_value_states = past[0], past[1] + encoder_outputs, past_key_values = past[0], past[1] return { "inputs": None, # inputs don't have to be defined, but still need to be passed to make Keras.layer.__call__ happy "decoder_input_ids": inputs, # inputs are the decoder_input_ids - "decoder_past_key_value_states": decoder_past_key_value_states, + "past_key_values": past_key_values, "encoder_outputs": encoder_outputs, "attention_mask": attention_mask, "use_cache": use_cache, diff --git a/src/transformers/modeling_transfo_xl.py b/src/transformers/modeling_transfo_xl.py index c57be4afd3..9b0e276e2b 100644 --- a/src/transformers/modeling_transfo_xl.py +++ b/src/transformers/modeling_transfo_xl.py @@ -661,6 +661,15 @@ class TransfoXLLMHeadModelOutput(ModelOutput): hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None + @property + def logits(self): + # prediciton scores are the output of the adaptive softmax, see + # the file `modeling_transfo_xl_utilities`. Since the adaptive + # softmax returns the log softmax value, `self.prediciton_scores` + # are strictly speaking not exactly `logits`, but behave the same + # way logits do. + return self.prediction_scores + TRANSFO_XL_START_DOCSTRING = r""" diff --git a/tests/test_modeling_encoder_decoder.py b/tests/test_modeling_encoder_decoder.py index 3af9fbc9c7..8aefee7f85 100644 --- a/tests/test_modeling_encoder_decoder.py +++ b/tests/test_modeling_encoder_decoder.py @@ -33,6 +33,7 @@ if is_torch_available(): from transformers import ( BertLMHeadModel, BertModel, + BertTokenizer, EncoderDecoderConfig, EncoderDecoderModel, GPT2LMHeadModel, @@ -128,10 +129,11 @@ class EncoderDecoderMixin: decoder_config, decoder_input_ids, decoder_attention_mask, + return_dict, **kwargs ): encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) - kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model} + kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model, "return_dict": return_dict} enc_dec_model = EncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs) enc_dec_model.to(torch_device) outputs_encoder_decoder = enc_dec_model( @@ -361,7 +363,11 @@ class EncoderDecoderMixin: def test_encoder_decoder_model_from_pretrained(self): input_ids_dict = self.prepare_config_and_inputs() - self.check_encoder_decoder_model_from_pretrained(**input_ids_dict) + self.check_encoder_decoder_model_from_pretrained(**input_ids_dict, return_dict=False) + + def test_encoder_decoder_model_from_pretrained_return_dict(self): + input_ids_dict = self.prepare_config_and_inputs() + self.check_encoder_decoder_model_from_pretrained(**input_ids_dict, return_dict=True) def test_save_and_load_from_pretrained(self): input_ids_dict = self.prepare_config_and_inputs() @@ -466,6 +472,22 @@ class BertEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase): "labels": decoder_token_labels, } + @slow + def test_bert2bert_summarization(self): + model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16") + model.to(torch_device) + tokenizer = BertTokenizer.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16") + + ARTICLE = """(CNN)Sigma Alpha Epsilon is under fire for a video showing party-bound fraternity members singing a racist chant. SAE's national chapter suspended the students, but University of Oklahoma President David Boren took it a step further, saying the university's affiliation with the fraternity is permanently done. The news is shocking, but it's not the first time SAE has faced controversy. SAE was founded March 9, 1856, at the University of Alabama, five years before the American Civil War, according to the fraternity website. When the war began, the group had fewer than 400 members, of which "369 went to war for the Confederate States and seven for the Union Army," the website says. The fraternity now boasts more than 200,000 living alumni, along with about 15,000 undergraduates populating 219 chapters and 20 "colonies" seeking full membership at universities. SAE has had to work hard to change recently after a string of member deaths, many blamed on the hazing of new recruits, SAE national President Bradley Cohen wrote in a message on the fraternity's website. The fraternity's website lists more than 130 chapters cited or suspended for "health and safety incidents" since 2010. At least 30 of the incidents involved hazing, and dozens more involved alcohol. However, the list is missing numerous incidents from recent months. Among them, according to various media outlets: Yale University banned the SAEs from campus activities last month after members allegedly tried to interfere with a sexual misconduct investigation connected to an initiation rite. Stanford University in December suspended SAE housing privileges after finding sorority members attending a fraternity function were subjected to graphic sexual content. And Johns Hopkins University in November suspended the fraternity for underage drinking. "The media has labeled us as the 'nation's deadliest fraternity,' " Cohen said. In 2011, for example, a student died while being coerced into excessive alcohol consumption, according to a lawsuit. SAE's previous insurer dumped the fraternity. "As a result, we are paying Lloyd's of London the highest insurance rates in the Greek-letter world," Cohen said. Universities have turned down SAE's attempts to open new chapters, and the fraternity had to close 12 in 18 months over hazing incidents.""" + + EXPECTED_SUMMARY = """sae was founded in 1856, five years before the civil war. the fraternity has had to work hard to change recently. the university of oklahoma president says the university's affiliation with the fraternity is permanently done. the sae has had a string of members in recent months.""" + + input_ids = tokenizer(ARTICLE, return_tensors="pt").input_ids.to(torch_device) + output_ids = model.generate(input_ids) + summary = tokenizer.decode(output_ids[0], skip_special_tokens=True) + + self.assertEqual(summary, EXPECTED_SUMMARY) + class RoBertaEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase): def get_encoder_decoder_model(self, config, decoder_config): diff --git a/tests/test_modeling_gpt2.py b/tests/test_modeling_gpt2.py index 17e0a6bc48..dcb0faefe4 100644 --- a/tests/test_modeling_gpt2.py +++ b/tests/test_modeling_gpt2.py @@ -289,9 +289,9 @@ class GPT2ModelTester: } result = model(**inputs) - self.parent.assertEqual(result.lm_loss.shape, ()) + self.parent.assertEqual(result.loss.shape, ()) self.parent.assertEqual( - result.lm_logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size) + result.logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size) ) self.parent.assertEqual(result.mc_logits.shape, (self.batch_size, self.num_choices)) @@ -324,7 +324,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else () all_generative_model_classes = ( - (GPT2LMHeadModel,) if is_torch_available() else () + (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else () ) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly test_missing_keys = False diff --git a/tests/test_modeling_openai.py b/tests/test_modeling_openai.py index 1014e1eea4..92a0335cda 100644 --- a/tests/test_modeling_openai.py +++ b/tests/test_modeling_openai.py @@ -131,8 +131,8 @@ class OpenAIGPTModelTester: model.eval() result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids) - self.parent.assertEqual(result.lm_loss.shape, ()) - self.parent.assertEqual(result.lm_logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + self.parent.assertEqual(result.loss.shape, ()) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() diff --git a/tests/test_modeling_t5.py b/tests/test_modeling_t5.py index fef623807c..4c411d8449 100644 --- a/tests/test_modeling_t5.py +++ b/tests/test_modeling_t5.py @@ -159,17 +159,15 @@ class T5ModelTester: ) result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) decoder_output = result.last_hidden_state - decoder_past = result.decoder_past_key_values + decoder_past = result.past_key_values encoder_output = result.encoder_last_hidden_state self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size)) self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.decoder_seq_length, self.hidden_size)) - self.parent.assertEqual(len(decoder_past), 2) - self.parent.assertTrue(torch.all(decoder_past[0][0] == encoder_output)) - # There should be `num_layers` key value embeddings stored in decoder_past[1] - self.parent.assertEqual(len(decoder_past[1]), config.num_layers) - # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past[1] tuple - self.parent.assertEqual(len(decoder_past[1][0]), 4) + # There should be `num_layers` key value embeddings stored in decoder_past + self.parent.assertEqual(len(decoder_past), config.num_layers) + # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past tuple + self.parent.assertEqual(len(decoder_past[0]), 4) def create_and_check_with_lm_head( self, diff --git a/tests/test_modeling_tf_gpt2.py b/tests/test_modeling_tf_gpt2.py index 41b973719e..4cd20be25e 100644 --- a/tests/test_modeling_tf_gpt2.py +++ b/tests/test_modeling_tf_gpt2.py @@ -238,7 +238,7 @@ class TFGPT2ModelTester: } result = model(inputs) self.parent.assertEqual( - result.lm_logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size) + result.logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size) ) self.parent.assertEqual(result.mc_logits.shape, (self.batch_size, self.num_choices)) diff --git a/tests/test_modeling_tf_openai.py b/tests/test_modeling_tf_openai.py index e3bd82dae2..6e57db2d39 100644 --- a/tests/test_modeling_tf_openai.py +++ b/tests/test_modeling_tf_openai.py @@ -151,7 +151,7 @@ class TFOpenAIGPTModelTester: } result = model(inputs) self.parent.assertEqual( - result.lm_logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size) + result.logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size) ) self.parent.assertEqual(result.mc_logits.shape, (self.batch_size, self.num_choices)) diff --git a/tests/test_modeling_tf_t5.py b/tests/test_modeling_tf_t5.py index eb575f5131..7c50bd15c5 100644 --- a/tests/test_modeling_tf_t5.py +++ b/tests/test_modeling_tf_t5.py @@ -96,7 +96,7 @@ class TFT5ModelTester: result = model(input_ids, decoder_attention_mask=input_mask, decoder_input_ids=input_ids) decoder_output = result.last_hidden_state - decoder_past = result.decoder_past_key_values + decoder_past = result.past_key_values encoder_output = result.encoder_last_hidden_state self.parent.assertListEqual(list(encoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size]) self.parent.assertListEqual(list(decoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])