diff --git a/docs/source/internal/generation_utils.rst b/docs/source/internal/generation_utils.rst index f645472ffa..64ebd17b9f 100644 --- a/docs/source/internal/generation_utils.rst +++ b/docs/source/internal/generation_utils.rst @@ -13,13 +13,102 @@ Utilities for Generation ----------------------------------------------------------------------------------------------------------------------- -This page lists all the utility functions used by :meth:`~transformers.PretrainedModel.generate`, -:meth:`~transformers.PretrainedModel.greedy_search`, :meth:`~transformers.PretrainedModel.sample`, -:meth:`~transformers.PretrainedModel.beam_search`, :meth:`~transformers.PretrainedModel.beam_sample`, and -:meth:`~transformers.PretrainedModel.group_beam_search`. +This page lists all the utility functions used by :meth:`~transformers.PreTrainedModel.generate`, +:meth:`~transformers.PreTrainedModel.greedy_search`, :meth:`~transformers.PreTrainedModel.sample`, +:meth:`~transformers.PreTrainedModel.beam_search`, :meth:`~transformers.PreTrainedModel.beam_sample`, and +:meth:`~transformers.PreTrainedModel.group_beam_search`. Most of those are only useful if you are studying the code of the generate methods in the library. +Generate Outputs +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The output of :meth:`~transformers.PreTrainedModel.generate` is an instance of a subclass of +:class:`~transformers.file_utils.ModelOutput`. This output is a data structure containing all the information returned +by :meth:`~transformers.PreTrainedModel.generate`, but that can also be used as tuple or dictionary. + +Here's an example: + +.. 
code-block:: + + from transformers import GPT2Tokenizer, GPT2LMHeadModel + + tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + model = GPT2LMHeadModel.from_pretrained('gpt2') + + inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt") + generation_output = model.generate(**inputs, return_dict_in_generate=True, output_scores=True) + +The ``generation_output`` object is a :class:`~transformers.generation_utils.GreedySearchDecoderOnlyOutput`, as we can +see in the documentation of that class below, it means it has the following attributes: + +- ``sequences``: the generated sequences of tokens +- ``scores`` (optional): the prediction scores of the language modeling head, for each generation step +- ``hidden_states`` (optional): the hidden states of the model, for each generation step +- ``attentions`` (optional): the attention weights of the model, for each generation step + +Here we have the ``scores`` since we passed along ``output_scores=True``, but we don't have ``hidden_states`` and +``attentions`` because we didn't pass ``output_hidden_states=True`` or ``output_attentions=True``. + +You can access each attribute as you would usually do, and if that attribute has not been returned by the model, you +will get ``None``. Here for instance ``generation_output.scores`` are all the generated prediction scores of the +language modeling head, and ``generation_output.attentions`` is ``None``. + +When using our ``generation_output`` object as a tuple, it only keeps the attributes that don't have ``None`` values. +Here, for instance, it has two elements, ``sequences`` then ``scores``, so + +.. code-block:: + + generation_output[:2] + +will return the tuple ``(generation_output.sequences, generation_output.scores)`` for instance. + +When using our ``generation_output`` object as a dictionary, it only keeps the attributes that don't have ``None`` +values. Here, for instance, it has two keys that are ``sequences`` and ``scores``. 
+ +We document here all output types. + + +GreedySearchOutput +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: transformers.generation_utils.GreedySearchDecoderOnlyOutput + :members: + +.. autoclass:: transformers.generation_utils.GreedySearchEncoderDecoderOutput + :members: + + +SampleOutput +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: transformers.generation_utils.SampleDecoderOnlyOutput + :members: + +.. autoclass:: transformers.generation_utils.SampleEncoderDecoderOutput + :members: + + +BeamSearchOutput +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: transformers.generation_utils.BeamSearchDecoderOnlyOutput + :members: + +.. autoclass:: transformers.generation_utils.BeamSearchEncoderDecoderOutput + :members: + + +BeamSampleOutput +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: transformers.generation_utils.BeamSampleDecoderOnlyOutput + :members: + +.. autoclass:: transformers.generation_utils.BeamSampleEncoderDecoderOutput + :members: + + LogitsProcessor ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index ba53a860cb..eeb8563cbe 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -124,6 +124,11 @@ class PretrainedConfig(object): - **num_return_sequences** (:obj:`int`, `optional`, defaults to 1) -- Number of independently computed returned sequences for each element in the batch that will be used by default in the :obj:`generate` method of the model. 
+ - **output_scores** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether the model should return the + logits when used for generation + - **return_dict_in_generate** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether the model should + return a :class:`~transformers.file_utils.ModelOutput` instead of a :obj:`torch.LongTensor` + Parameters for fine-tuning tasks @@ -203,6 +208,8 @@ class PretrainedConfig(object): self.bad_words_ids = kwargs.pop("bad_words_ids", None) self.num_return_sequences = kwargs.pop("num_return_sequences", 1) self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0) + self.output_scores = kwargs.pop("output_scores", False) + self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False) # Fine-tuning task arguments self.architectures = kwargs.pop("architectures", None) @@ -343,6 +350,7 @@ class PretrainedConfig(object): Passing :obj:`use_auth_token=True` is required when you want to use a private model. + Returns: :class:`PretrainedConfig`: The configuration object instantiated from this pretrained model. @@ -372,6 +380,8 @@ class PretrainedConfig(object): From a ``pretrained_model_name_or_path``, resolve to a dictionary of parameters, to be used for instantiating a :class:`~transformers.PretrainedConfig` using ``from_dict``. + + Parameters: pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): The identifier of the pre-trained checkpoint from which we want the dictionary of parameters. 
diff --git a/src/transformers/generation_beam_search.py b/src/transformers/generation_beam_search.py index b04c93d567..a2e2cb4753 100644 --- a/src/transformers/generation_beam_search.py +++ b/src/transformers/generation_beam_search.py @@ -281,7 +281,7 @@ class BeamSearchScorer(BeamScorer): final_beam_indices: torch.LongTensor, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, - ) -> torch.LongTensor: + ) -> Tuple[torch.LongTensor]: batch_size = len(self._beam_hyps) # finalize all open beam hypotheses and add to generated hypotheses @@ -300,14 +300,20 @@ class BeamSearchScorer(BeamScorer): # select the best hypotheses sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep) best = [] + best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32) # retrieve best hypotheses for i, beam_hyp in enumerate(self._beam_hyps): sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0]) for j in range(self.num_beam_hyps_to_keep): - best_hyp = sorted_hyps.pop()[1] + best_hyp_tuple = sorted_hyps.pop() + best_score = best_hyp_tuple[0] + best_hyp = best_hyp_tuple[1] sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp) + + # append to lists best.append(best_hyp) + best_scores[i * self.num_beam_hyps_to_keep + j] = best_score # prepare for adding eos sent_max_len = min(sent_lengths.max().item() + 1, self.max_length) @@ -322,7 +328,12 @@ class BeamSearchScorer(BeamScorer): decoded[i, : sent_lengths[i]] = hypo if sent_lengths[i] < self.max_length: decoded[i, sent_lengths[i]] = eos_token_id - return decoded + return UserDict( + { + "sequences": decoded, + "sequence_scores": best_scores, + } + ) class BeamHypotheses: diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 4c2f20f040..fb55b97859 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -14,7 +14,8 @@ # See the License for the specific language governing 
permissions and # limitations under the License. -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple +from dataclasses import dataclass +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import torch from torch.nn import functional as F @@ -39,6 +40,299 @@ from .utils import logging logger = logging.get_logger(__name__) +@dataclass +class GreedySearchDecoderOnlyOutput(ModelOutput): + """ + Base class for outputs of decoder-only generation models using greedy search. + + + Args: + sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or + shorter if all batches finished early due to the :obj:`eos_token_id`. + scores (:obj:`tuple(torch.FloatTensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``): + Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. :obj:`(max_length,)`-shaped tuple of :obj:`torch.FloatTensor` with each tensor of + shape :obj:`(batch_size, config.vocab_size)`). + attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + :obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`. + hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + :obj:`torch.FloatTensor` of shape :obj:`(batch_size, generated_length, hidden_size)`. 
+ """ + + sequences: torch.LongTensor = None + scores: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + + +@dataclass +class GreedySearchEncoderDecoderOutput(ModelOutput): + """ + Base class for outputs of encoder-decoder generation models using greedy search. Hidden states and attention + weights of the encoder (respectively the decoder) can be accessed via the encoder_attentions and the + encoder_hidden_states attributes (respectively the decoder_attentions and the decoder_hidden_states attributes) + + + Args: + sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or + shorter if all batches finished early due to the :obj:`eos_token_id`. + scores (:obj:`tuple(torch.FloatTensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``): + Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. :obj:`(max_length,)`-shaped tuple of :obj:`torch.FloatTensor` with each tensor of + shape :obj:`(batch_size, config.vocab_size)`). + encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer of the encoder) of shape :obj:`(batch_size, + num_heads, sequence_length, sequence_length)`. + encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. 
+ decoder_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + :obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`. + decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + :obj:`torch.FloatTensor` of shape :obj:`(batch_size, generated_length, hidden_size)`. + """ + + sequences: torch.LongTensor = None + scores: Optional[Tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + + +@dataclass +class SampleDecoderOnlyOutput(ModelOutput): + """ + Base class for outputs of decoder-only generation models using sampling. + + + Args: + sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or + shorter if all batches finished early due to the :obj:`eos_token_id`. + scores (:obj:`tuple(torch.FloatTensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``): + Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. :obj:`(max_length,)`-shaped tuple of :obj:`torch.FloatTensor` with each tensor of + shape :obj:`(batch_size * num_return_sequences, config.vocab_size)`). 
+ attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + :obj:`torch.FloatTensor` of shape :obj:`(num_return_sequences * batch_size, num_heads, generated_length, + sequence_length)`. + hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + :obj:`torch.FloatTensor` of shape :obj:`(num_return_sequences * batch_size, generated_length, + hidden_size)`. + """ + + sequences: torch.LongTensor = None + scores: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + + +@dataclass +class SampleEncoderDecoderOutput(ModelOutput): + """ + Base class for outputs of encoder-decoder generation models using sampling. Hidden states and attention weights of + the encoder (respectively the decoder) can be accessed via the encoder_attentions and the encoder_hidden_states + attributes (respectively the decoder_attentions and the decoder_hidden_states attributes) + + + Args: + sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or + shorter if all batches finished early due to the :obj:`eos_token_id`. + scores (:obj:`tuple(torch.FloatTensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``): + Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. 
:obj:`(max_length,)`-shaped tuple of :obj:`torch.FloatTensor` with each tensor of + shape :obj:`(batch_size * num_return_sequences, config.vocab_size)`). + encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer of the encoder) of shape :obj:`(batch_size * + num_return_sequences, num_heads, sequence_length, sequence_length)`. + encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size * num_return_sequences, sequence_length, hidden_size)`. + decoder_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + :obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_return_sequences, num_heads, generated_length, + sequence_length)`. + decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + :obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_return_sequences, generated_length, + hidden_size)`. 
+ """ + + sequences: torch.LongTensor = None + scores: Optional[Tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + + +@dataclass +class BeamSearchDecoderOnlyOutput(ModelOutput): + """ + Base class for outputs of decoder-only generation models using beam search. + + Args: + sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or + shorter if all batches finished early due to the :obj:`eos_token_id`. + sequences_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_return_sequences)`, `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``): + Final beam scores of the generated ``sequences``. + scores (:obj:`tuple(torch.FloatTensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``): + Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log + softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam + . :obj:`(max_length,)`-shaped tuple of :obj:`torch.FloatTensor` with each tensor of shape :obj:`(batch_size + * num_beams * num_return_sequences, config.vocab_size)`). + attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + :obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams, num_heads, generated_length, + sequence_length)`. 
+ hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + :obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams * num_return_sequences, generated_length, + hidden_size)`. + """ + + sequences: torch.LongTensor = None + sequences_scores: Optional[torch.FloatTensor] = None + scores: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + + +@dataclass +class BeamSearchEncoderDecoderOutput(ModelOutput): + """ + Base class for outputs of encoder-decoder generation models using beam search. Hidden states and attention weights + of the encoder (respectively the decoder) can be accessed via the encoder_attentions and the encoder_hidden_states + attributes (respectively the decoder_attentions and the decoder_hidden_states attributes) + + Args: + sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or + shorter if all batches finished early due to the :obj:`eos_token_id`. + sequences_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_return_sequences)`, `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``): + Final beam scores of the generated ``sequences``. + scores (:obj:`tuple(torch.FloatTensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``): + Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log + softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam + . 
:obj:`(max_length,)`-shaped tuple of :obj:`torch.FloatTensor` with each tensor of shape :obj:`(batch_size + * num_beams, config.vocab_size)`). + attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): + encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer of the encoder) of shape :obj:`(batch_size, + num_heads, sequence_length, sequence_length)`. + encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size * num_beams * num_return_sequences, sequence_length, hidden_size)`. + decoder_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + :obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams * num_return_sequences, num_heads, + generated_length, sequence_length)`. + decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + :obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams * num_return_sequences, generated_length, + hidden_size)`. 
+ """ + + sequences: torch.LongTensor = None + sequences_scores: Optional[torch.FloatTensor] = None + scores: Optional[Tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + + +@dataclass +class BeamSampleDecoderOnlyOutput(ModelOutput): + """ + Base class for outputs of decoder-only generation models using beam sample. + + Args: + sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or + shorter if all batches finished early due to the :obj:`eos_token_id`. + sequences_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_return_sequences)`, `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``): + Final beam scores of the generated ``sequences``. + scores (:obj:`tuple(torch.FloatTensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``): + Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log + softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam + . :obj:`(max_length,)`-shaped tuple of :obj:`torch.FloatTensor` with each tensor of shape :obj:`(batch_size + * num_beams * num_return_sequences, config.vocab_size)`). 
+ attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + :obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams, num_heads, generated_length, + sequence_length)`. + hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + :obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams, generated_length, hidden_size)`. + """ + + sequences: torch.LongTensor = None + sequences_scores: Optional[torch.FloatTensor] = None + scores: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + + +@dataclass +class BeamSampleEncoderDecoderOutput(ModelOutput): + """ + Base class for outputs of encoder-decoder generation models using beam sampling. Hidden states and attention + weights of the encoder (respectively the decoder) can be accessed via the encoder_attentions and the + encoder_hidden_states attributes (respectively the decoder_attentions and the decoder_hidden_states attributes) + + Args: + sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or + shorter if all batches finished early due to the :obj:`eos_token_id`. + sequences_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_return_sequences)`, `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``): + Final beam scores of the generated ``sequences``. 
+ scores (:obj:`tuple(torch.FloatTensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``): + Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log + softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam + . :obj:`(max_length,)`-shaped tuple of :obj:`torch.FloatTensor` with each tensor of shape :obj:`(batch_size + * num_beams, config.vocab_size)`). + encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer of the encoder) of shape :obj:`(batch_size, + num_heads, sequence_length, sequence_length)`. + encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size * num_beams, sequence_length, hidden_size)`. + decoder_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + :obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams, num_heads, generated_length, + sequence_length)`. + decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + :obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams, generated_length, hidden_size)`. 
+ """ + + sequences: torch.LongTensor = None + sequences_scores: Optional[torch.FloatTensor] = None + scores: Optional[Tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + + +GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput] +SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput] +BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput] +BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput] + + class GenerationMixin: """ A class containing all of the functions supporting generation, to be used as a mixin in @@ -139,7 +433,7 @@ class GenerationMixin: is_encoder_decoder: bool = False, attention_mask: torch.LongTensor = None, encoder_outputs: ModelOutput = None, - **model_kwargs + **model_kwargs, ) -> Tuple[torch.LongTensor, Dict[str, Any]]: expanded_return_idx = ( torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device) @@ -327,8 +621,12 @@ class GenerationMixin: num_beam_groups: Optional[int] = None, diversity_penalty: Optional[float] = None, prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, - **model_kwargs - ) -> torch.LongTensor: + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + **model_kwargs, + ) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]: r""" Generates sequences for models with a language modeling head. 
The method currently supports greedy decoding, multinomial sampling, beam-search decoding, and beam-search multinomial sampling. @@ -407,18 +705,44 @@ class GenerationMixin: conditioned on the previously generated tokens :obj:`inputs_ids` and the batch ID :obj:`batch_id`. This argument is useful for constrained generation conditioned on the prefix, as described in `Autoregressive Entity Retrieval `__. + output_attentions (:obj:`bool`, `optional`, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more details. + output_hidden_states (:obj:`bool`, `optional`, defaults to `False`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors + for more details. + output_scores (:obj:`bool`, `optional`, defaults to `False`): + Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details. + return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. + model_kwargs: Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If the - model is an Encoder-Decoder model, encoder specific kwargs should not be prefixed and decoder specific + model is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with `decoder_`. Return: - :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated - sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all - batches finished early due to the :obj:`eos_token_id`. 
+ :class:`~transformers.file_utils.ModelOutput` or :obj:`torch.LongTensor`: A + :class:`~transformers.file_utils.ModelOutput` (if ``return_dict_in_generate=True`` or when + ``config.return_dict_in_generate=True``) or a :obj:`torch.LongTensor`. + + If the model is `not` an encoder-decoder model (``model.config.is_encoder_decoder=False``), the + possible :class:`~transformers.file_utils.ModelOutput` types are: + + - :class:`~transformers.generation_utils.GreedySearchDecoderOnlyOutput`, + - :class:`~transformers.generation_utils.SampleDecoderOnlyOutput`, + - :class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput`, + - :class:`~transformers.generation_utils.BeamSampleDecoderOnlyOutput` + + If the model is an encoder-decoder model (``model.config.is_encoder_decoder=True``), the possible + :class:`~transformers.file_utils.ModelOutput` types are: + + - :class:`~transformers.generation_utils.GreedySearchEncoderDecoderOutput`, + - :class:`~transformers.generation_utils.SampleEncoderDecoderOutput`, + - :class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput`, + - :class:`~transformers.generation_utils.BeamSampleEncoderDecoderOutput` Examples:: - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2") @@ -483,6 +807,18 @@ class GenerationMixin: bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + output_scores = output_scores if output_scores is not None else self.config.output_scores + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate if return_dict_in_generate is not None else
self.config.return_dict_in_generate + ) + + model_kwargs["output_attentions"] = output_attentions + model_kwargs["output_hidden_states"] = output_hidden_states + if input_ids is None: # init `input_ids` with bos_token_id input_ids = self._prepare_input_ids_for_generation(bos_token_id) @@ -552,6 +888,8 @@ class GenerationMixin: max_length=max_length, pad_token_id=pad_token_id, eos_token_id=eos_token_id, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, **model_kwargs, ) @@ -577,6 +915,8 @@ class GenerationMixin: max_length=max_length, pad_token_id=pad_token_id, eos_token_id=eos_token_id, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, **model_kwargs, ) @@ -609,6 +949,8 @@ class GenerationMixin: max_length=max_length, pad_token_id=pad_token_id, eos_token_id=eos_token_id, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, **model_kwargs, ) @@ -645,6 +987,8 @@ class GenerationMixin: max_length=max_length, pad_token_id=pad_token_id, eos_token_id=eos_token_id, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, **model_kwargs, ) @@ -681,6 +1025,8 @@ class GenerationMixin: max_length=max_length, pad_token_id=pad_token_id, eos_token_id=eos_token_id, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, **model_kwargs, ) @@ -691,11 +1037,17 @@ class GenerationMixin: max_length: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, - **model_kwargs - ): + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + **model_kwargs, + ) -> Union[GreedySearchOutput, torch.LongTensor]: r""" Generates sequences for models with a language modeling head using greedy decoding. 
+ + Parameters: input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): @@ -711,14 +1063,29 @@ class GenerationMixin: The id of the `padding` token. eos_token_id (:obj:`int`, `optional`): The id of the `end-of-sequence` token. + output_attentions (:obj:`bool`, `optional`, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more details. + output_hidden_states (:obj:`bool`, `optional`, defaults to `False`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors + for more details. + output_scores (:obj:`bool`, `optional`, defaults to `False`): + Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details. + return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. + model_kwargs: Additional model specific keyword arguments will be forwarded to the :obj:`forward` function of the model. If model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`. Return: - :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated - sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all - batches finished early due to the :obj:`eos_token_id`.
+ :class:`~transformers.generation_utils.GreedySearchDecoderOnlyOutput`, + :class:`~transformers.generation_utils.GreedySearchEncoderDecoderOutput` or :obj:`torch.LongTensor`: A + :obj:`torch.LongTensor` containing the generated tokens (default behaviour) or a + :class:`~transformers.generation_utils.GreedySearchDecoderOnlyOutput` if + ``model.config.is_encoder_decoder=False`` and ``return_dict_in_generate=True`` or a + :class:`~transformers.generation_utils.GreedySearchEncoderDecoderOutput` if + ``model.config.is_encoder_decoder=True``. Examples:: @@ -747,12 +1114,31 @@ class GenerationMixin: >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) """ - # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() max_length = max_length if max_length is not None else self.config.max_length pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + output_scores = output_scores if output_scores is not None else self.config.output_scores + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions =
model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) # init sequence length tensors sequence_lengths, unfinished_sequences, cur_len = self._init_sequence_length_for_generation( @@ -764,14 +1150,35 @@ class GenerationMixin: model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) # forward pass to get next token - outputs = self(**model_inputs, return_dict=True) + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) next_token_logits = outputs.logits[:, -1, :] + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_logits,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + # pre-process distribution - scores = logits_processor(input_ids, next_token_logits) + next_tokens_scores = logits_processor(input_ids, next_token_logits) # argmax - next_tokens = torch.argmax(scores, dim=-1) + next_tokens = torch.argmax(next_tokens_scores, dim=-1) # add code that transfomers next_tokens to tokens_to_add if eos_token_id is not None: @@ -799,7 +1206,25 @@ class GenerationMixin: # increase cur_len cur_len = cur_len + 1 - return input_ids + if return_dict_in_generate: + if self.config.is_encoder_decoder: + return GreedySearchEncoderDecoderOutput( + sequences=input_ids, + scores=scores, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + decoder_hidden_states=decoder_hidden_states, + ) + else: + return 
GreedySearchDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + ) + else: + return input_ids def sample( self, @@ -809,8 +1234,12 @@ class GenerationMixin: max_length: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, - **model_kwargs - ): + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + **model_kwargs, + ) -> Union[SampleOutput, torch.LongTensor]: r""" Generates sequences for models with a language modeling head using multinomial sampling. @@ -833,14 +1262,28 @@ class GenerationMixin: The id of the `padding` token. eos_token_id (:obj:`int`, `optional`): The id of the `end-of-sequence` token. + output_attentions (:obj:`bool`, `optional`, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more details. + output_hidden_states (:obj:`bool`, `optional`, defaults to `False`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors + for more details. + output_scores (:obj:`bool`, `optional`, defaults to `False`): + Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details. + return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. model_kwargs: Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`. Return: - :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated - sequences.
The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all - batches finished early due to the :obj:`eos_token_id`. + :class:`~transformers.generation_utils.SampleDecoderOnlyOutput`, + :class:`~transformers.generation_utils.SampleEncoderDecoderOutput` or :obj:`torch.LongTensor`: A + :obj:`torch.LongTensor` containing the generated tokens (default behaviour) or a + :class:`~transformers.generation_utils.SampleDecoderOnlyOutput` if + ``model.config.is_encoder_decoder=False`` and ``return_dict_in_generate=True`` or a + :class:`~transformers.generation_utils.SampleEncoderDecoderOutput` if + ``model.config.is_encoder_decoder=True``. Examples:: @@ -883,6 +1326,26 @@ class GenerationMixin: max_length = max_length if max_length is not None else self.config.max_length pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + output_scores = output_scores if output_scores is not None else self.config.output_scores + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions
else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) # init sequence length tensors sequence_lengths, unfinished_sequences, cur_len = self._init_sequence_length_for_generation( @@ -895,15 +1358,36 @@ class GenerationMixin: model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) # forward pass to get next token - outputs = self(**model_inputs, return_dict=True) + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) next_token_logits = outputs.logits[:, -1, :] # pre-process distribution - scores = logits_processor(input_ids, next_token_logits) - scores = logits_warper(input_ids, scores) + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) # sample - probs = F.softmax(scores, dim=-1) + probs = F.softmax(next_token_scores, dim=-1) next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) # add code that transfomers next_tokens to tokens_to_add @@ -930,7 +1414,25 @@ class GenerationMixin: outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) - return input_ids + if return_dict_in_generate: + if self.config.is_encoder_decoder: + return SampleEncoderDecoderOutput( + sequences=input_ids, + scores=scores, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, 
+ decoder_hidden_states=decoder_hidden_states, + ) + else: + return SampleDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + ) + else: + return input_ids def beam_search( self, @@ -940,8 +1442,12 @@ class GenerationMixin: max_length: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, - **model_kwargs - ): + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + **model_kwargs, + ) -> Union[BeamSearchOutput, torch.LongTensor]: r""" Generates sequences for models with a language modeling head using beam search decoding. @@ -964,14 +1470,29 @@ class GenerationMixin: The id of the `padding` token. eos_token_id (:obj:`int`, `optional`): The id of the `end-of-sequence` token. + output_attentions (:obj:`bool`, `optional`, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more details. + output_hidden_states (:obj:`bool`, `optional`, defaults to `False`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors + for more details. + output_scores (:obj:`bool`, `optional`, defaults to `False`): + Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details. + return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. model_kwargs: Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.
Return: - :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated - sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all - batches finished early due to the :obj:`eos_token_id`. + :class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput`, + :class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` or :obj:`torch.LongTensor`: A + :obj:`torch.LongTensor` containing the generated tokens (default behaviour) or a + :class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput` if + ``model.config.is_encoder_decoder=False`` and ``return_dict_in_generate=True`` or a + :class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` if + ``model.config.is_encoder_decoder=True``. + Examples:: @@ -1025,6 +1546,26 @@ class GenerationMixin: max_length = max_length if max_length is not None else self.config.max_length pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + output_scores = output_scores if output_scores is not None else self.config.output_scores + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if
return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) batch_size = len(beam_scorer._beam_hyps) num_beams = beam_scorer.num_beams @@ -1042,7 +1583,12 @@ class GenerationMixin: while cur_len < max_length: model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - outputs = self(**model_inputs, return_dict=True) + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) next_token_logits = outputs.logits[:, -1, :] # adjust tokens for Bart, *e.g.* @@ -1054,6 +1600,23 @@ class GenerationMixin: next_token_scores = logits_processor(input_ids, next_token_scores) next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + # reshape for beam search vocab_size = next_token_scores.shape[-1] next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) @@ -1090,11 +1653,33 @@ class GenerationMixin: if beam_scorer.is_done: break - decoded = beam_scorer.finalize( + sequence_outputs = beam_scorer.finalize( input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id ) - return decoded + if return_dict_in_generate: + if not output_scores: + sequence_outputs["sequence_scores"] = None + if 
self.config.is_encoder_decoder: + return BeamSearchEncoderDecoderOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + decoder_hidden_states=decoder_hidden_states, + ) + else: + return BeamSearchDecoderOnlyOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + ) + else: + return sequence_outputs["sequences"] def beam_sample( self, @@ -1105,8 +1690,12 @@ class GenerationMixin: max_length: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, - **model_kwargs - ): + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + **model_kwargs, + ) -> Union[BeamSampleOutput, torch.LongTensor]: r""" Generates sequences for models with a language modeling head using beam search with multinomial sampling. @@ -1133,14 +1722,28 @@ class GenerationMixin: The id of the `padding` token. eos_token_id (:obj:`int`, `optional`): The id of the `end-of-sequence` token. + output_attentions (:obj:`bool`, `optional`, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more details. + output_hidden_states (:obj:`bool`, `optional`, defaults to `False`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors + for more details. + output_scores (:obj:`bool`, `optional`, defaults to `False`): + Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details.
+ return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. model_kwargs: Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`. Return: - :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated - sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all - batches finished early due to the :obj:`eos_token_id`. + :class:`~transformers.generation_utils.BeamSampleDecoderOnlyOutput`, + :class:`~transformers.generation_utils.BeamSampleEncoderDecoderOutput` or :obj:`torch.LongTensor`: A + :obj:`torch.LongTensor` containing the generated tokens (default behaviour) or a + :class:`~transformers.generation_utils.BeamSampleDecoderOnlyOutput` if + ``model.config.is_encoder_decoder=False`` and ``return_dict_in_generate=True`` or a + :class:`~transformers.generation_utils.BeamSampleEncoderDecoderOutput` if + ``model.config.is_encoder_decoder=True``.
Examples:: @@ -1202,6 +1805,26 @@ class GenerationMixin: max_length = max_length if max_length is not None else self.config.max_length pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + output_scores = output_scores if output_scores is not None else self.config.output_scores + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) batch_size = len(beam_scorer._beam_hyps) num_beams = beam_scorer.num_beams @@ -1214,7 +1837,12 @@ class GenerationMixin: while cur_len < max_length: model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - outputs = self(**model_inputs, return_dict=True) + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) next_token_logits = outputs.logits[:, -1, :] # adjust token scores (a no-op by default) @@ -1228,6 +1856,22 @@ class 
GenerationMixin: next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores) next_token_scores = logits_warper(input_ids, next_token_scores) + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + # reshape for beam search vocab_size = next_token_scores.shape[-1] next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) @@ -1267,11 +1911,33 @@ class GenerationMixin: if beam_scorer.is_done: break - decoded = beam_scorer.finalize( + sequence_outputs = beam_scorer.finalize( input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id ) - return decoded + if return_dict_in_generate: + if not output_scores: + sequence_outputs["sequence_scores"] = None + if self.config.is_encoder_decoder: + return BeamSearchEncoderDecoderOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + decoder_hidden_states=decoder_hidden_states, + ) + else: + return BeamSearchDecoderOnlyOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + ) + else: + return sequence_outputs["sequences"] def group_beam_search( self, @@ -1281,7 +1947,11 @@ class GenerationMixin: max_length: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, - 
**model_kwargs + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + **model_kwargs, ): r""" Generates sequences for models with a language modeling head using beam search decoding. @@ -1305,14 +1975,29 @@ class GenerationMixin: The id of the `padding` token. eos_token_id (:obj:`int`, `optional`): The id of the `end-of-sequence` token. + output_attentions (:obj:`bool`, `optional`, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more details. + output_hidden_states (:obj:`bool`, `optional`, defaults to `False`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors + for more details. + output_scores (:obj:`bool`, `optional`, defaults to `False`): + Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details. + return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. model_kwargs: Additional model specific kwargs that will be forwarded to the :obj:`forward` function of the model. If model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`. Return: - :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated - sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all - batches finished early due to the :obj:`eos_token_id`.
+ :class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput`, + :class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` or :obj:`torch.LongTensor`: A + :obj:`torch.LongTensor` containing the generated tokens (default behaviour) or a + :class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput` if + ``model.config.is_encoder_decoder=False`` and + ``return_dict_in_generate=True`` or a + :class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` if + ``model.config.is_encoder_decoder=True``. Examples:: @@ -1369,6 +2054,26 @@ class GenerationMixin: max_length = max_length if max_length is not None else self.config.max_length pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + output_scores = output_scores if output_scores is not None else self.config.output_scores + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( +
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) batch_size = len(beam_scorer._beam_hyps) num_beams = beam_scorer.num_beams @@ -1397,7 +2102,12 @@ class GenerationMixin: # do one decoder step on all beams of all sentences in batch model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - outputs = self(**model_inputs, return_dict=True) + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) for beam_group_idx in range(num_beam_groups): group_start_idx = beam_group_idx * num_sub_beams @@ -1406,6 +2116,10 @@ class GenerationMixin: # indices of beams of current group among all sentences in batch batch_group_indices = [] + + if output_scores: + processed_score = torch.zeros_like(outputs.logits[:, -1, :]) + for batch_idx in range(batch_size): batch_group_indices.extend( [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] @@ -1429,8 +2143,11 @@ class GenerationMixin: next_token_scores = next_token_scores + beam_scores[batch_group_indices].unsqueeze(-1).expand_as( next_token_scores ) - # reshape for beam search + if output_scores: + processed_score[batch_group_indices] = next_token_scores + + # reshape for beam search next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) next_token_scores, next_tokens = torch.topk( @@ -1463,6 +2180,22 @@ class GenerationMixin: num_beams * (beam_idx // group_size) + group_start_idx + (beam_idx % group_size) ) + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (processed_score,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else 
(outputs.hidden_states,) + ) + model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) @@ -1474,11 +2207,33 @@ class GenerationMixin: if beam_scorer.is_done: break - decoded = beam_scorer.finalize( + sequence_outputs = beam_scorer.finalize( input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id ) - return decoded + if return_dict_in_generate: + if not output_scores: + sequence_outputs["sequence_scores"] = None + if self.config.is_encoder_decoder: + return BeamSearchEncoderDecoderOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + decoder_hidden_states=decoder_hidden_states, + ) + else: + return BeamSearchDecoderOnlyOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + ) + else: + return sequence_outputs["sequences"] def top_k_top_p_filtering( diff --git a/src/transformers/models/openai/configuration_openai.py b/src/transformers/models/openai/configuration_openai.py index b2072df4cf..1e7bf8ec8c 100644 --- a/src/transformers/models/openai/configuration_openai.py +++ b/src/transformers/models/openai/configuration_openai.py @@ -136,7 +136,6 @@ class OpenAIGPTConfig(PretrainedConfig): summary_activation=None, summary_proj_to_labels=True, summary_first_dropout=0.1, - use_cache=True, **kwargs ): super().__init__(**kwargs) @@ -159,7 +158,6 @@ class OpenAIGPTConfig(PretrainedConfig): self.summary_activation = summary_activation self.summary_first_dropout = summary_first_dropout self.summary_proj_to_labels = summary_proj_to_labels - self.use_cache = use_cache @property def max_position_embeddings(self): diff --git 
a/tests/test_generation_beam_search.py b/tests/test_generation_beam_search.py index 10a932395f..aa8270c31f 100644 --- a/tests/test_generation_beam_search.py +++ b/tests/test_generation_beam_search.py @@ -190,7 +190,7 @@ class BeamSearchTester: input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1) # finalize - decoded = beam_scorer.finalize( + sequence_output = beam_scorer.finalize( input_ids, output_scores, output_tokens, @@ -198,19 +198,27 @@ class BeamSearchTester: pad_token_id=self.pad_token_id, eos_token_id=self.eos_token_id, ) + + sequences = sequence_output["sequences"] + sequence_scores = sequence_output["sequence_scores"] + # since `num_beam_hyps_to_keep` = 1 => only return `batch_size` x `max_length` - self.parent.assertListEqual(list(decoded.shape), [self.batch_size, max_length]) + self.parent.assertListEqual(list(sequences.shape), [self.batch_size, max_length]) + self.parent.assertListEqual(list(sequence_scores.shape), [self.batch_size]) + + # check sequence_scores + self.parent.assertFalse((sequence_scores > 0).any().item()) # first batch has to finish with eos_token - self.parent.assertEqual(decoded[0, -1].item(), self.eos_token_id) + self.parent.assertEqual(sequences[0, -1].item(), self.eos_token_id) # other batches cannot finish with eos token - self.parent.assertNotEqual(decoded[1, -1].item(), self.eos_token_id) - self.parent.assertNotEqual(decoded[2, -1].item(), self.eos_token_id) + self.parent.assertNotEqual(sequences[1, -1].item(), self.eos_token_id) + self.parent.assertNotEqual(sequences[2, -1].item(), self.eos_token_id) # now test that if `num_beam_hyps_to_keep` is 3 => all beams are returned beam_scorer.num_beam_hyps_to_keep = self.num_beams - decoded = beam_scorer.finalize( + sequence_output = beam_scorer.finalize( input_ids, output_scores, output_tokens, @@ -218,7 +226,11 @@ class BeamSearchTester: pad_token_id=self.pad_token_id, eos_token_id=self.eos_token_id, ) - 
self.parent.assertListEqual(list(decoded.shape), [self.num_beams * self.batch_size, max_length]) + sequences = sequence_output["sequences"] + sequence_scores = sequence_output["sequence_scores"] + + self.parent.assertListEqual(list(sequences.shape), [self.num_beams * self.batch_size, max_length]) + self.parent.assertListEqual(list(sequence_scores.shape), [self.num_beams * self.batch_size]) @require_torch diff --git a/tests/test_generation_utils.py b/tests/test_generation_utils.py index ce0fe08fe0..5359f348f8 100644 --- a/tests/test_generation_utils.py +++ b/tests/test_generation_utils.py @@ -36,6 +36,14 @@ if is_torch_available(): TopKLogitsWarper, TopPLogitsWarper, ) + from transformers.generation_utils import ( + BeamSearchDecoderOnlyOutput, + BeamSearchEncoderDecoderOutput, + GreedySearchDecoderOnlyOutput, + GreedySearchEncoderDecoderOutput, + SampleDecoderOnlyOutput, + SampleEncoderDecoderOutput, + ) class GenerationTesterMixin: @@ -146,9 +154,16 @@ class GenerationTesterMixin: return beam_kwargs, beam_scorer @staticmethod - def _get_encoder_outputs(model, input_ids, attention_mask, num_interleave=1): + def _get_encoder_outputs( + model, input_ids, attention_mask, output_attentions=None, output_hidden_states=None, num_interleave=1 + ): encoder = model.get_encoder() - encoder_outputs = encoder(input_ids, attention_mask=attention_mask) + encoder_outputs = encoder( + input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave( num_interleave, dim=0 ) @@ -156,181 +171,480 @@ class GenerationTesterMixin: attention_mask = None return encoder_outputs, input_ids, attention_mask + def _greedy_generate( + self, + model, + input_ids, + attention_mask, + max_length, + output_scores=False, + output_attentions=False, + output_hidden_states=False, + return_dict_in_generate=False, + ): + logits_process_kwargs, 
logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], model.config.eos_token_id + ) + + kwargs = {} + if model.config.is_encoder_decoder: + max_length = 4 + + output_generate = model.generate( + input_ids, + attention_mask=attention_mask, + do_sample=False, + num_beams=1, + max_length=max_length, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + **logits_process_kwargs, + ) + + if model.config.is_encoder_decoder: + encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( + model, + input_ids, + attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + kwargs["encoder_outputs"] = encoder_outputs + + with torch.no_grad(): + output_greedy = model.greedy_search( + input_ids, + max_length=max_length, + attention_mask=attention_mask, + logits_processor=logits_processor, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + **kwargs, + ) + return output_greedy, output_generate + + def _sample_generate( + self, + model, + input_ids, + attention_mask, + max_length, + num_return_sequences, + logits_processor, + logits_warper, + logits_warper_kwargs, + process_kwargs, + output_scores=False, + output_attentions=False, + output_hidden_states=False, + return_dict_in_generate=False, + ): + torch.manual_seed(0) + output_generate = model.generate( + input_ids, + do_sample=True, + num_beams=1, + max_length=max_length, + num_return_sequences=num_return_sequences, + attention_mask=attention_mask, + output_scores=output_scores, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict_in_generate=return_dict_in_generate, + **logits_warper_kwargs, + **process_kwargs, + ) + + torch.manual_seed(0) + kwargs = {} + if 
model.config.is_encoder_decoder: + encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs( + model, + input_ids, + attention_mask, + num_interleave=num_return_sequences, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + kwargs["encoder_outputs"] = encoder_outputs + input_ids_clone = input_ids_clone.repeat_interleave(num_return_sequences, dim=0) + else: + attention_mask_clone = attention_mask.repeat_interleave(num_return_sequences, dim=0) + input_ids_clone = input_ids.repeat_interleave(num_return_sequences, dim=0) + + with torch.no_grad(): + output_sample = model.sample( + input_ids_clone, + attention_mask=attention_mask_clone, + max_length=max_length, + logits_processor=logits_processor, + logits_warper=logits_warper, + output_scores=output_scores, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict_in_generate=return_dict_in_generate, + **kwargs, + ) + return output_sample, output_generate + + def _beam_search_generate( + self, + model, + input_ids, + attention_mask, + max_length, + beam_scorer, + beam_kwargs, + logits_processor, + logits_process_kwargs, + output_scores=False, + output_attentions=False, + output_hidden_states=False, + return_dict_in_generate=False, + ): + output_generate = model.generate( + input_ids, + attention_mask=attention_mask, + do_sample=False, + max_length=max_length, + output_scores=output_scores, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict_in_generate=return_dict_in_generate, + **beam_kwargs, + **logits_process_kwargs, + ) + + # beam_search does not automatically interleave `batch_size` dim for `num_beams` + kwargs = {} + if model.config.is_encoder_decoder: + encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs( + model, + input_ids, + attention_mask, + num_interleave=beam_scorer.num_beams, + output_attentions=output_attentions, + 
output_hidden_states=output_hidden_states, + ) + kwargs["encoder_outputs"] = encoder_outputs + input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0) + else: + attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0) + input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0) + + with torch.no_grad(): + output_beam_search = model.beam_search( + input_ids_clone, + beam_scorer, + max_length=max_length, + attention_mask=attention_mask_clone, + logits_processor=logits_processor, + output_scores=output_scores, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict_in_generate=return_dict_in_generate, + **kwargs, + ) + return output_generate, output_beam_search + + def _beam_sample_generate( + self, + model, + input_ids, + attention_mask, + max_length, + num_return_sequences, + beam_scorer, + beam_kwargs, + logits_warper, + logits_warper_kwargs, + output_scores=False, + output_attentions=False, + output_hidden_states=False, + return_dict_in_generate=False, + ): + torch.manual_seed(0) + output_generate = model.generate( + input_ids, + attention_mask=attention_mask, + do_sample=True, + max_length=max_length, + output_scores=output_scores, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict_in_generate=return_dict_in_generate, + **beam_kwargs, + **logits_warper_kwargs, + ) + # beam_search does not automatically interleave `batch_size` dim for `num_beams * num_return_sequences` + kwargs = {} + if model.config.is_encoder_decoder: + encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( + model, + input_ids, + attention_mask, + num_interleave=beam_scorer.num_beams * num_return_sequences, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + kwargs["encoder_outputs"] = encoder_outputs + else: + attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams 
* num_return_sequences, dim=0) + + torch.manual_seed(0) + with torch.no_grad(): + output_beam_sample = model.beam_sample( + input_ids.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0), + beam_scorer, + max_length=max_length, + attention_mask=attention_mask, + logits_warper=logits_warper, + output_scores=output_scores, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict_in_generate=return_dict_in_generate, + **kwargs, + ) + + return output_generate, output_beam_sample + + def _group_beam_search_generate( + self, + model, + input_ids, + attention_mask, + max_length, + beam_scorer, + beam_kwargs, + logits_processor, + logits_process_kwargs, + output_scores=False, + output_attentions=False, + output_hidden_states=False, + return_dict_in_generate=False, + ): + output_generate = model.generate( + input_ids, + attention_mask=attention_mask, + do_sample=False, + max_length=max_length, + output_scores=output_scores, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict_in_generate=return_dict_in_generate, + **beam_kwargs, + **logits_process_kwargs, + ) + + # group_beam_search does not automatically interleave `batch_size` dim for `num_beams` + kwargs = {} + if model.config.is_encoder_decoder: + encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs( + model, + input_ids, + attention_mask, + num_interleave=beam_scorer.num_beams, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + kwargs["encoder_outputs"] = encoder_outputs + input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0) + else: + attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0) + input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0) + + with torch.no_grad(): + output_group_beam_search = model.group_beam_search( + input_ids_clone, + beam_scorer, + 
max_length=max_length, + attention_mask=attention_mask_clone, + logits_processor=logits_processor, + output_scores=output_scores, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict_in_generate=return_dict_in_generate, + **kwargs, + ) + return output_generate, output_group_beam_search + def test_greedy_generate(self): + # check `generate()` and `greedy_search()` are equal for model_class in self.all_generative_model_classes: config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - - logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( - input_ids.shape[-1], config.eos_token_id + # test old generation output for backwards compatibility + model = model_class(config).to(torch_device).eval() + output_greedy, output_generate = self._greedy_generate( + model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length ) + self.assertListEqual(output_greedy.tolist(), output_generate.tolist()) - model = model_class(config).to(torch_device) - model.eval() - - # check `generate()` and `greedy_search()` are equal - kwargs = {} - if model.config.is_encoder_decoder: - max_length = 4 - - output_ids_generate = model.generate( - input_ids, + def test_greedy_generate_dict_outputs(self): + for model_class in self.all_generative_model_classes: + # disable cache + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config.use_cache = False + model = model_class(config).to(torch_device).eval() + output_greedy, output_generate = self._greedy_generate( + model=model, + input_ids=input_ids, attention_mask=attention_mask, - do_sample=False, - num_beams=1, max_length=max_length, - **logits_process_kwargs, + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, ) if model.config.is_encoder_decoder: - encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( - model, input_ids, 
attention_mask - ) - kwargs["encoder_outputs"] = encoder_outputs + self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput) + self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput) + else: + self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput) + self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput) - with torch.no_grad(): - output_ids_greedy = model.greedy_search( - input_ids, - max_length=max_length, - attention_mask=attention_mask, - logits_processor=logits_processor, - **kwargs, - ) - self.assertListEqual(output_ids_generate.tolist(), output_ids_greedy.tolist()) + self.assertListEqual(output_generate.sequences.tolist(), output_greedy.sequences.tolist()) + + for output in (output_greedy, output_generate): + self._check_outputs(output, input_ids, model.config) + + def test_greedy_generate_dict_outputs_use_cache(self): + for model_class in self.all_generative_model_classes: + # enable cache + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + + if not hasattr(config, "use_cache"): + # only relevant if model has "use_cache" + return + + config.use_cache = True + model = model_class(config).to(torch_device).eval() + output_greedy, output_generate = self._greedy_generate( + model=model, + input_ids=input_ids, + attention_mask=attention_mask, + max_length=max_length, + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + + self.assertListEqual(output_generate.sequences.tolist(), output_greedy.sequences.tolist()) + + for output in (output_greedy, output_generate): + self._check_outputs(output, input_ids, model.config, use_cache=True) def test_sample_generate(self): for model_class in self.all_generative_model_classes: config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + model = model_class(config).to(torch_device).eval() process_kwargs, logits_processor = 
self._get_logits_processor_and_kwargs( - input_ids.shape[-1], config.eos_token_id + input_ids.shape[-1], model.config.eos_token_id ) logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) - model = model_class(config).to(torch_device) - model.eval() + if model.config.is_encoder_decoder: + max_length = 4 # check `generate()` and `sample()` are equal - if model.config.is_encoder_decoder: - max_length = 4 - - torch.manual_seed(0) - output_ids_generate = model.generate( - input_ids, - do_sample=True, - num_beams=1, - max_length=max_length, + output_sample, output_generate = self._sample_generate( + model=model, + input_ids=input_ids, attention_mask=attention_mask, - **logits_warper_kwargs, - **process_kwargs, + max_length=max_length, + num_return_sequences=1, + logits_processor=logits_processor, + logits_warper=logits_warper, + logits_warper_kwargs=logits_warper_kwargs, + process_kwargs=process_kwargs, ) - - torch.manual_seed(0) - kwargs = {} - if model.config.is_encoder_decoder: - encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs( - model, input_ids, attention_mask - ) - kwargs["encoder_outputs"] = encoder_outputs - else: - attention_mask_clone = attention_mask - input_ids_clone = input_ids - - with torch.no_grad(): - output_ids_sample = model.sample( - input_ids_clone, - attention_mask=attention_mask_clone, - max_length=max_length, - logits_processor=logits_processor, - logits_warper=logits_warper, - **kwargs, - ) - self.assertListEqual(output_ids_generate.tolist(), output_ids_sample.tolist()) + self.assertListEqual(output_sample.tolist(), output_generate.tolist()) # check `generate()` and `sample()` yield equal results for `num_return_sequences` - num_return_sequences = 3 + output_sample, output_generate = self._sample_generate( + model=model, + input_ids=input_ids, + attention_mask=attention_mask, + max_length=max_length, + num_return_sequences=3, + logits_processor=logits_processor, + 
logits_warper=logits_warper, + logits_warper_kwargs=logits_warper_kwargs, + process_kwargs=process_kwargs, + ) + self.assertListEqual(output_sample.tolist(), output_generate.tolist()) + + def test_sample_generate_dict_output(self): + for model_class in self.all_generative_model_classes: + # disable cache + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config.use_cache = False + model = model_class(config).to(torch_device).eval() + process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], model.config.eos_token_id + ) + logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) + if model.config.is_encoder_decoder: max_length = 4 - torch.manual_seed(0) - output_ids_generate = model.generate( - input_ids, - do_sample=True, - num_beams=1, - max_length=max_length, - num_return_sequences=num_return_sequences, + output_sample, output_generate = self._sample_generate( + model=model, + input_ids=input_ids, attention_mask=attention_mask, - **logits_warper_kwargs, - **process_kwargs, + max_length=max_length, + num_return_sequences=2, + logits_processor=logits_processor, + logits_warper=logits_warper, + logits_warper_kwargs=logits_warper_kwargs, + process_kwargs=process_kwargs, + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, ) - torch.manual_seed(0) - kwargs = {} if model.config.is_encoder_decoder: - encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs( - model, input_ids, attention_mask, num_interleave=num_return_sequences - ) - kwargs["encoder_outputs"] = encoder_outputs - input_ids_clone = input_ids_clone.repeat_interleave(num_return_sequences, dim=0) + self.assertIsInstance(output_sample, SampleEncoderDecoderOutput) + self.assertIsInstance(output_generate, SampleEncoderDecoderOutput) else: - attention_mask_clone = attention_mask.repeat_interleave(num_return_sequences, dim=0) - 
input_ids_clone = input_ids.repeat_interleave(num_return_sequences, dim=0) + self.assertIsInstance(output_sample, SampleDecoderOnlyOutput) + self.assertIsInstance(output_generate, SampleDecoderOnlyOutput) - with torch.no_grad(): - output_ids_sample = model.sample( - input_ids_clone, - attention_mask=attention_mask_clone, - max_length=max_length, - logits_processor=logits_processor, - logits_warper=logits_warper, - **kwargs, - ) - self.assertListEqual(output_ids_generate.tolist(), output_ids_sample.tolist()) + self.assertListEqual(output_generate.sequences.tolist(), output_sample.sequences.tolist()) + + for output in (output_sample, output_generate): + self._check_outputs(output, input_ids, model.config, num_return_sequences=2) def test_beam_search_generate(self): for model_class in self.all_generative_model_classes: config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + model = model_class(config).to(torch_device).eval() logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( input_ids.shape[-1], config.eos_token_id ) - - model = model_class(config).to(torch_device) - model.eval() - - # check `generate()` and `beam_search()` are equal if model.config.is_encoder_decoder: max_length = 4 beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length) - output_ids_generate = model.generate( - input_ids, + + # check `generate()` and `beam_search()` are equal + output_generate, output_beam_search = self._beam_search_generate( + model=model, + input_ids=input_ids, attention_mask=attention_mask, - do_sample=False, max_length=max_length, - **beam_kwargs, - **logits_process_kwargs, + beam_scorer=beam_scorer, + beam_kwargs=beam_kwargs, + logits_process_kwargs=logits_process_kwargs, + logits_processor=logits_processor, ) - - # beam_search does not automatically interleave `batch_size` dim for `num_beams` - kwargs = {} - if model.config.is_encoder_decoder: - encoder_outputs, input_ids_clone, 
attention_mask_clone = self._get_encoder_outputs( - model, input_ids, attention_mask, num_interleave=beam_scorer.num_beams - ) - kwargs["encoder_outputs"] = encoder_outputs - input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0) - else: - attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0) - input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0) - - with torch.no_grad(): - output_ids_beam_search = model.beam_search( - input_ids_clone, - beam_scorer, - max_length=max_length, - attention_mask=attention_mask_clone, - logits_processor=logits_processor, - **kwargs, - ) - self.assertListEqual(output_ids_generate.tolist(), output_ids_beam_search.tolist()) + self.assertListEqual(output_generate.tolist(), output_beam_search.tolist()) # check `generate()` and `beam_search()` are equal for `num_return_sequences` num_return_sequences = 2 @@ -340,36 +654,104 @@ class GenerationTesterMixin: input_ids.shape[0], max_length, num_return_sequences=num_return_sequences ) - output_ids_generate = model.generate( - input_ids, + output_generate, output_beam_search = self._beam_search_generate( + model=model, + input_ids=input_ids, attention_mask=attention_mask, - do_sample=False, max_length=max_length, - **beam_kwargs, - **logits_process_kwargs, + beam_scorer=beam_scorer, + beam_kwargs=beam_kwargs, + logits_process_kwargs=logits_process_kwargs, + logits_processor=logits_processor, ) - # beam_search does not automatically interleave `batch_size` dim for `num_beams` - kwargs = {} - if model.config.is_encoder_decoder: - encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs( - model, input_ids, attention_mask, num_interleave=beam_scorer.num_beams - ) - kwargs["encoder_outputs"] = encoder_outputs - input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0) - else: - attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0) - 
input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0) + self.assertListEqual(output_generate.tolist(), output_beam_search.tolist()) - with torch.no_grad(): - output_ids_beam_search = model.beam_search( - input_ids_clone, - beam_scorer, - max_length=max_length, - attention_mask=attention_mask_clone, - logits_processor=logits_processor, - **kwargs, + def test_beam_search_generate_dict_output(self): + for model_class in self.all_generative_model_classes: + # disable cache + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config.use_cache = False + model = model_class(config).to(torch_device).eval() + logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], config.eos_token_id + ) + if model.config.is_encoder_decoder: + max_length = 4 + beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length) + output_generate, output_beam_search = self._beam_search_generate( + model=model, + input_ids=input_ids, + attention_mask=attention_mask, + max_length=max_length, + beam_scorer=beam_scorer, + beam_kwargs=beam_kwargs, + logits_process_kwargs=logits_process_kwargs, + logits_processor=logits_processor, + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + if model.config.is_encoder_decoder: + self.assertIsInstance(output_beam_search, BeamSearchEncoderDecoderOutput) + self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) + else: + self.assertIsInstance(output_beam_search, BeamSearchDecoderOnlyOutput) + self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) + + self.assertListEqual(output_generate.sequences.tolist(), output_beam_search.sequences.tolist()) + self.assertTrue( + torch.allclose(output_generate["sequences_scores"], output_beam_search["sequences_scores"], atol=1e-3) + ) + self.assertTrue(output_generate["sequences_scores"].shape == 
(output_generate["sequences"].shape[0],)) + self.assertTrue((output_generate["sequences_scores"] < 0).all().item()) + + for output in (output_beam_search, output_generate): + self._check_outputs(output, input_ids, model.config, num_return_sequences=beam_scorer.num_beams) + + def test_beam_search_generate_dict_outputs_use_cache(self): + for model_class in self.all_generative_model_classes: + # enable cache + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + + if not hasattr(config, "use_cache"): + # only relevant if model has "use_cache" + return + + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + model = model_class(config).to(torch_device).eval() + + logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], config.eos_token_id + ) + + if model.config.is_encoder_decoder: + max_length = 4 + beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length) + + config.use_cache = True + model = model_class(config).to(torch_device).eval() + output_beam, output_generate = self._beam_search_generate( + model=model, + input_ids=input_ids, + attention_mask=attention_mask, + max_length=max_length, + beam_scorer=beam_scorer, + beam_kwargs=beam_kwargs, + logits_process_kwargs=logits_process_kwargs, + logits_processor=logits_processor, + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + + self.assertListEqual(output_generate.sequences.tolist(), output_beam.sequences.tolist()) + + for output in (output_beam, output_generate): + self._check_outputs( + output, input_ids, model.config, use_cache=True, num_return_sequences=beam_scorer.num_beams ) - self.assertListEqual(output_ids_generate.tolist(), output_ids_beam_search.tolist()) def test_beam_sample_generate(self): for model_class in self.all_generative_model_classes: @@ -377,8 +759,7 @@ class GenerationTesterMixin: print("Return 
dict", config.return_dict) logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) - model = model_class(config).to(torch_device) - model.eval() + model = model_class(config).to(torch_device).eval() # check `generate()` and `beam_search()` are equal # change `num_return_sequences = 2` but not for `beam_scorer` @@ -389,54 +770,88 @@ class GenerationTesterMixin: input_ids.shape[0] * num_return_sequences, max_length ) beam_kwargs["num_return_sequences"] = num_return_sequences - torch.manual_seed(0) - output_ids_generate = model.generate( - input_ids, + + output_generate, output_beam_sample = self._beam_sample_generate( + model=model, + input_ids=input_ids, attention_mask=attention_mask, - do_sample=True, max_length=max_length, - **beam_kwargs, - **logits_warper_kwargs, + num_return_sequences=num_return_sequences, + beam_scorer=beam_scorer, + beam_kwargs=beam_kwargs, + logits_warper=logits_warper, + logits_warper_kwargs=logits_warper_kwargs, ) - # beam_search does not automatically interleave `batch_size` dim for `num_beams * num_return_sequences` - kwargs = {} + self.assertListEqual(output_generate.tolist(), output_beam_sample.tolist()) + + def test_beam_sample_generate_dict_output(self): + for model_class in self.all_generative_model_classes: + # disable cache + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config.use_cache = False + model = model_class(config).to(torch_device).eval() + logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) + + num_return_sequences = 2 if model.config.is_encoder_decoder: - encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( - model, input_ids, attention_mask, num_interleave=beam_scorer.num_beams * num_return_sequences - ) - kwargs["encoder_outputs"] = encoder_outputs + max_length = 4 + beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs( + input_ids.shape[0] * num_return_sequences, max_length + ) + 
beam_kwargs["num_return_sequences"] = num_return_sequences + + output_beam_sample, output_generate = self._beam_sample_generate( + model=model, + input_ids=input_ids, + attention_mask=attention_mask, + max_length=max_length, + num_return_sequences=num_return_sequences, + beam_scorer=beam_scorer, + beam_kwargs=beam_kwargs, + logits_warper=logits_warper, + logits_warper_kwargs=logits_warper_kwargs, + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + + if model.config.is_encoder_decoder: + self.assertIsInstance(output_beam_sample, BeamSearchEncoderDecoderOutput) + self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) else: - attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0) + self.assertIsInstance(output_beam_sample, BeamSearchDecoderOnlyOutput) + self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) - torch.manual_seed(0) - with torch.no_grad(): - output_ids_beam_sample = model.beam_sample( - input_ids.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0), - beam_scorer, - max_length=max_length, - attention_mask=attention_mask, - logits_warper=logits_warper, - **kwargs, - ) - self.assertListEqual(output_ids_generate.tolist(), output_ids_beam_sample.tolist()) + self.assertListEqual(output_generate.sequences.tolist(), output_beam_sample.sequences.tolist()) + self.assertTrue( + torch.allclose(output_generate["sequences_scores"], output_beam_sample["sequences_scores"], atol=1e-3) + ) + self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],)) + self.assertTrue((output_generate["sequences_scores"] < 0).all().item()) - def test_generate_without_input_ids(self): - config, _, _, max_length = self._get_input_ids_and_config() - - # if no bos token id => cannot generate from None - if config.bos_token_id is None: - return - - for model_class in 
self.all_generative_model_classes: - model = model_class(config).to(torch_device) - model.eval() - - output_ids_generate = model.generate( - do_sample=False, - max_length=max_length, + for output in (output_beam_sample, output_generate): + self._check_outputs( + output, input_ids, model.config, num_return_sequences=num_return_sequences * beam_scorer.num_beams ) - self.assertIsNotNone(output_ids_generate) + def test_generate_without_input_ids(self): + config, _, _, max_length = self._get_input_ids_and_config() + + # if no bos token id => cannot generate from None + if config.bos_token_id is None: + return + + for model_class in self.all_generative_model_classes: + model = model_class(config).to(torch_device) + model.eval() + + output_ids_generate = model.generate( + do_sample=False, + max_length=max_length, + ) + + self.assertIsNotNone(output_ids_generate) def test_group_beam_search_generate(self): for model_class in self.all_generative_model_classes: @@ -446,44 +861,23 @@ class GenerationTesterMixin: input_ids.shape[-1], config.eos_token_id, diversity_penalty=2.0 ) - model = model_class(config).to(torch_device) - model.eval() + model = model_class(config).to(torch_device).eval() # check `generate()` and `group_beam_search()` are equal if model.config.is_encoder_decoder: max_length = 4 beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(input_ids.shape[0], max_length) - output_ids_generate = model.generate( - input_ids, + output_generate, output_group_beam_search = self._group_beam_search_generate( + model=model, + input_ids=input_ids, attention_mask=attention_mask, - do_sample=False, max_length=max_length, - **beam_kwargs, - **logits_process_kwargs, + beam_scorer=beam_scorer, + beam_kwargs=beam_kwargs, + logits_processor=logits_processor, + logits_process_kwargs=logits_process_kwargs, ) - - # group_beam_search does not automatically interleave `batch_size` dim for `num_beams` - kwargs = {} - if model.config.is_encoder_decoder: - encoder_outputs, 
input_ids_clone, attention_mask_clone = self._get_encoder_outputs( - model, input_ids, attention_mask, num_interleave=beam_scorer.num_beams - ) - kwargs["encoder_outputs"] = encoder_outputs - input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0) - else: - attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0) - input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0) - - with torch.no_grad(): - output_ids_group_beam_search = model.group_beam_search( - input_ids_clone, - beam_scorer, - max_length=max_length, - attention_mask=attention_mask_clone, - logits_processor=logits_processor, - **kwargs, - ) - self.assertListEqual(output_ids_generate.tolist(), output_ids_group_beam_search.tolist()) + self.assertListEqual(output_generate.tolist(), output_group_beam_search.tolist()) # check `generate()` and `group_beam_search()` are equal for `num_return_sequences` num_return_sequences = 2 @@ -492,37 +886,190 @@ class GenerationTesterMixin: beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs( input_ids.shape[0], max_length, num_return_sequences=num_return_sequences ) - - output_ids_generate = model.generate( - input_ids, + output_generate, output_group_beam_search = self._group_beam_search_generate( + model=model, + input_ids=input_ids, attention_mask=attention_mask, - do_sample=False, max_length=max_length, - **beam_kwargs, - **logits_process_kwargs, + beam_scorer=beam_scorer, + beam_kwargs=beam_kwargs, + logits_processor=logits_processor, + logits_process_kwargs=logits_process_kwargs, ) - # group_beam_search does not automatically interleave `batch_size` dim for `num_beams` - kwargs = {} - if model.config.is_encoder_decoder: - encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs( - model, input_ids, attention_mask, num_interleave=beam_scorer.num_beams - ) - kwargs["encoder_outputs"] = encoder_outputs - input_ids_clone = 
input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0) - else: - attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0) - input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0) + self.assertListEqual(output_generate.tolist(), output_group_beam_search.tolist()) - with torch.no_grad(): - output_ids_beam_search = model.group_beam_search( - input_ids_clone, - beam_scorer, - max_length=max_length, - attention_mask=attention_mask_clone, - logits_processor=logits_processor, - **kwargs, + def test_group_beam_search_generate_dict_output(self): + for model_class in self.all_generative_model_classes: + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config.use_cache = False + model = model_class(config).to(torch_device).eval() + + logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], config.eos_token_id, diversity_penalty=2.0 + ) + + num_return_sequences = 1 + if model.config.is_encoder_decoder: + max_length = 4 + beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs( + input_ids.shape[0], max_length, num_return_sequences=num_return_sequences + ) + output_generate, output_group_beam_search = self._group_beam_search_generate( + model=model, + input_ids=input_ids, + attention_mask=attention_mask, + max_length=max_length, + beam_scorer=beam_scorer, + beam_kwargs=beam_kwargs, + logits_processor=logits_processor, + logits_process_kwargs=logits_process_kwargs, + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + if model.config.is_encoder_decoder: + self.assertIsInstance(output_group_beam_search, BeamSearchEncoderDecoderOutput) + self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) + else: + self.assertIsInstance(output_group_beam_search, BeamSearchDecoderOnlyOutput) + self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) + + 
self.assertListEqual(output_generate.sequences.tolist(), output_group_beam_search.sequences.tolist()) + self.assertTrue( + torch.allclose( + output_generate["sequences_scores"], output_group_beam_search["sequences_scores"], atol=1e-3 ) - self.assertListEqual(output_ids_generate.tolist(), output_ids_beam_search.tolist()) + ) + self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],)) + self.assertTrue((output_generate["sequences_scores"] < 0).all().item()) + + for output in (output_group_beam_search, output_generate): + self._check_outputs( + output, input_ids, model.config, num_return_sequences=num_return_sequences * beam_scorer.num_beams + ) + + def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1): + batch_size, seq_length = input_ids.shape + num_sequences_in_output = batch_size * num_return_sequences + gen_len = ( + output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length + ) + + # scores + self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config) + + # Attentions + if config.is_encoder_decoder: + # encoder + encoder_expected_shape = (batch_size, config.num_attention_heads, seq_length, seq_length) + self.assertIsInstance(output.encoder_attentions, tuple) + self.assertListEqual( + [layer_attentions.shape for layer_attentions in output.encoder_attentions], + [encoder_expected_shape] * len(output.encoder_attentions), + ) + # decoder + self._check_attentions_for_generate( + num_sequences_in_output, + output.decoder_attentions, + min_length=1, + max_length=output.sequences.shape[-1], + config=config, + use_cache=use_cache, + ) + else: + # if use_cache first input is equal to no use_cache, so skip here + attentions = output.attentions if not use_cache else output.attentions[1:] + min_length = seq_length if not use_cache else seq_length + 1 + self._check_attentions_for_generate( + num_sequences_in_output, + 
attentions=attentions, + min_length=min_length, + max_length=output.sequences.shape[-1], + config=config, + use_cache=use_cache, + ) + + # Hidden States + if config.is_encoder_decoder: + # encoder + encoder_expected_shape = (batch_size, seq_length, config.hidden_size) + self.assertIsInstance(output.encoder_hidden_states, tuple) + self.assertListEqual( + [layer_hidden_states.shape for layer_hidden_states in output.encoder_hidden_states], + [encoder_expected_shape] * len(output.encoder_hidden_states), + ) + + # decoder + self._check_hidden_states_for_generate( + num_sequences_in_output, + output.decoder_hidden_states, + min_length=1, + max_length=output.sequences.shape[-1], + config=config, + use_cache=use_cache, + ) + else: + # if use_cache first input is equal to no use_cache, so skip here + hidden_states = output.hidden_states if not use_cache else output.hidden_states[1:] + min_length = seq_length if not use_cache else seq_length + 1 + self._check_hidden_states_for_generate( + num_sequences_in_output, + hidden_states, + min_length=min_length, + max_length=output.sequences.shape[-1], + config=config, + use_cache=use_cache, + ) + + def _check_scores(self, batch_size, scores, length, config): + expected_shape = (batch_size, config.vocab_size) + self.assertIsInstance(scores, tuple) + self.assertEqual(len(scores), length) + self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores)) + + def _check_attentions_for_generate( + self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 + ): + self.assertIsInstance(attentions, tuple) + self.assertListEqual( + [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions) + ) + self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups) + + for idx, iter_attentions in enumerate(attentions): + tgt_len = min_length + idx if not use_cache else 1 + src_len = min_length + idx + + expected_shape = ( + 
batch_size * num_beam_groups, + config.num_attention_heads, + tgt_len, + src_len, + ) + # check attn size + self.assertListEqual( + [layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions) + ) + + def _check_hidden_states_for_generate( + self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1 + ): + self.assertIsInstance(hidden_states, tuple) + self.assertListEqual( + [isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states], + [True] * len(hidden_states), + ) + self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups) + + for idx, iter_hidden_states in enumerate(hidden_states): + seq_len = min_length + idx if not use_cache else 1 + expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size) + # check hidden size + self.assertListEqual( + [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states], + [expected_shape] * len(iter_hidden_states), + ) @require_torch diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 571da15f8b..817d35c5b9 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -638,6 +638,69 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod model = ReformerModelWithLMHead.from_pretrained(model_name) self.assertIsNotNone(model) + def _check_attentions_for_generate( + self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 + ): + self.assertIsInstance(attentions, tuple) + self.assertListEqual( + [isinstance(iter_attentions, list) for iter_attentions in attentions], [True] * len(attentions) + ) + self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups) + + for idx, iter_attentions in enumerate(attentions): + tgt_len = min_length + idx if not use_cache else 1 + num_chunks = tgt_len // config.local_attn_chunk_length + (tgt_len % 
config.local_attn_chunk_length != 0) + tgt_chunk_len = config.local_attn_chunk_length + src_chunk_len = config.local_attn_chunk_length * ( + 1 + config.local_num_chunks_after + config.local_num_chunks_before + ) + + if use_cache: + expected_shape = ( + batch_size * num_beam_groups, + config.num_attention_heads, + tgt_len, + min_length // config.local_attn_chunk_length + 1 + idx, + ) + else: + expected_shape = ( + batch_size * num_beam_groups, + config.num_attention_heads, + num_chunks, + tgt_chunk_len, + src_chunk_len, + ) + # check attn size + self.assertListEqual( + [layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions) + ) + + def _check_hidden_states_for_generate( + self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1 + ): + self.assertIsInstance(hidden_states, tuple) + self.assertListEqual( + [isinstance(iter_hidden_states, list) for iter_hidden_states in hidden_states], + [True] * len(hidden_states), + ) + self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups) + + for idx, iter_hidden_states in enumerate(hidden_states): + seq_len = min_length + idx + seq_len = config.local_attn_chunk_length * ( + seq_len // config.local_attn_chunk_length + (seq_len % config.local_attn_chunk_length != 0) + ) + + if use_cache: + seq_len = 1 + + expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size) + # check hidden size + self.assertListEqual( + [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states], + [expected_shape] * len(iter_hidden_states), + ) + @require_torch class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): @@ -696,13 +759,77 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, Generation self.model_tester = ReformerModelTester(self, **tester_kwargs) self.config_tester = ConfigTester(self, config_class=ReformerConfig, 
hidden_size=37) + def _check_attentions_for_generate( + self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 + ): + self.assertIsInstance(attentions, tuple) + self.assertListEqual( + [isinstance(iter_attentions, list) for iter_attentions in attentions], [True] * len(attentions) + ) + self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups) + + for idx, iter_attentions in enumerate(attentions): + tgt_len = min_length + idx if not use_cache else 1 + num_chunks = tgt_len // config.lsh_attn_chunk_length + (tgt_len % config.lsh_attn_chunk_length != 0) + tgt_chunk_len = config.lsh_attn_chunk_length + src_chunk_len = config.lsh_attn_chunk_length * ( + 1 + config.lsh_num_chunks_after + config.lsh_num_chunks_before + ) + + if use_cache: + expected_shape = ( + batch_size * num_beam_groups, + config.num_attention_heads, + config.num_hashes, + tgt_len, + config.num_hashes * (1 + config.lsh_num_chunks_after + config.lsh_num_chunks_before), + ) + else: + expected_shape = ( + batch_size * num_beam_groups, + config.num_attention_heads, + num_chunks * config.num_hashes, + tgt_chunk_len, + src_chunk_len, + ) + # check attn size + self.assertListEqual( + [layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions) + ) + + def _check_hidden_states_for_generate( + self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1 + ): + self.assertIsInstance(hidden_states, tuple) + self.assertListEqual( + [isinstance(iter_hidden_states, list) for iter_hidden_states in hidden_states], + [True] * len(hidden_states), + ) + self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups) + + for idx, iter_hidden_states in enumerate(hidden_states): + seq_len = min_length + idx if not use_cache else 1 + seq_len = config.lsh_attn_chunk_length * ( + seq_len // config.lsh_attn_chunk_length + (seq_len % config.lsh_attn_chunk_length != 0) + 
) + + if use_cache: + seq_len = 1 + + expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size) + # check hidden size + self.assertListEqual( + [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states], + [expected_shape] * len(iter_hidden_states), + ) + @require_torch @require_sentencepiece @require_tokenizers class ReformerIntegrationTests(unittest.TestCase): """ - These integration tests test the current layer activations and gradients againts the output of the Hugging Face Reformer model at time of integration: 29/06/2020. During integration, the model was tested against the output of the official Trax ReformerLM model for various cases ("lsh" only, "local" only, masked / non-masked, different chunk length, ....). In order to recover the original trax integration tests, one should use patrickvonplaten's fork of trax and the code that lives on the branch `reformer_trax_tests`. + These integration tests test the current layer activations and gradients against the output of the Hugging Face Reformer model at time of integration: 29/06/2020. During integration, the model was tested against the output of the official Trax ReformerLM model for various cases ("lsh" only, "local" only, masked / non-masked, different chunk length, ....). In order to recover the original trax integration tests, one should use patrickvonplaten's fork of trax and the code that lives on the branch `reformer_trax_tests`. 
""" def _get_basic_config_and_input(self): diff --git a/tests/test_modeling_transfo_xl.py b/tests/test_modeling_transfo_xl.py index 20e15bd1a4..6f771ece01 100644 --- a/tests/test_modeling_transfo_xl.py +++ b/tests/test_modeling_transfo_xl.py @@ -304,6 +304,50 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC # transfo-xl requires special resize for lm-head return + def _check_attentions_for_generate( + self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 + ): + self.assertIsInstance(attentions, tuple) + self.assertListEqual( + [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions) + ) + self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups) + + for idx, iter_attentions in enumerate(attentions): + tgt_len = min_length if idx == 0 else (min_length - 2) + src_len = (min_length + config.mem_len) if idx == 0 else (min_length + config.mem_len - 2) + + expected_shape = ( + batch_size * num_beam_groups, + config.num_attention_heads, + tgt_len, + src_len, + ) + + # check attn size + self.assertListEqual( + [layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions) + ) + + def _check_hidden_states_for_generate( + self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1 + ): + self.assertIsInstance(hidden_states, tuple) + self.assertListEqual( + [isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states], + [True] * len(hidden_states), + ) + self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups) + + for idx, iter_hidden_states in enumerate(hidden_states): + seq_len = min_length if idx == 0 else min_length - 2 + expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size) + # check hidden size + self.assertListEqual( + [layer_hidden_states.shape for layer_hidden_states in 
iter_hidden_states], + [expected_shape] * len(iter_hidden_states), + ) + @require_torch class TransfoXLModelLanguageGenerationTest(unittest.TestCase): diff --git a/tests/test_modeling_xlm.py b/tests/test_modeling_xlm.py index 57ab48ab52..69f76b88c9 100644 --- a/tests/test_modeling_xlm.py +++ b/tests/test_modeling_xlm.py @@ -400,6 +400,52 @@ class XLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_xlm_for_multiple_choice(*config_and_inputs) + def _check_attentions_for_generate( + self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 + ): + self.assertIsInstance(attentions, tuple) + self.assertListEqual( + [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions) + ) + self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups) + + for idx, iter_attentions in enumerate(attentions): + # adds PAD dummy token + tgt_len = min_length + idx + 1 + src_len = min_length + idx + 1 + + expected_shape = ( + batch_size * num_beam_groups, + config.num_attention_heads, + tgt_len, + src_len, + ) + # check attn size + self.assertListEqual( + [layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions) + ) + + def _check_hidden_states_for_generate( + self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1 + ): + self.assertIsInstance(hidden_states, tuple) + self.assertListEqual( + [isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states], + [True] * len(hidden_states), + ) + self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups) + + for idx, iter_hidden_states in enumerate(hidden_states): + # adds PAD dummy token + seq_len = min_length + idx + 1 + expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size) + # 
check hidden size + self.assertListEqual( + [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states], + [expected_shape] * len(iter_hidden_states), + ) + pass + @slow def test_model_from_pretrained(self): for model_name in XLM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: diff --git a/tests/test_modeling_xlnet.py b/tests/test_modeling_xlnet.py index b1888e0f18..1423ef6980 100644 --- a/tests/test_modeling_xlnet.py +++ b/tests/test_modeling_xlnet.py @@ -593,6 +593,60 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase) # xlnet cannot keep gradients in attentions or hidden states return + def _check_hidden_states_for_generate( + self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1 + ): + self.assertIsInstance(hidden_states, tuple) + self.assertListEqual( + [isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states], + [True] * len(hidden_states), + ) + self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups) + + for idx, iter_hidden_states in enumerate(hidden_states): + # check hidden size + for i, layer_hidden_states in enumerate(iter_hidden_states): + # every 2nd tensor is from extra stream + if i % 2 != 0: + seq_len = 1 + else: + # for first item dummy PAD token is appended so need one more + seq_len = (min_length + 1) if idx == 0 else min_length + + expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size) + self.assertEqual(layer_hidden_states.shape, expected_shape) + + def _check_attentions_for_generate( + self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 + ): + self.assertIsInstance(attentions, tuple) + self.assertListEqual( + [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions) + ) + self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups) + + for idx, attentions_item in enumerate(attentions): + for 
iter_attentions in attentions_item: + tgt_len = min_length + + # for first item dummy PAD token is appended so need one more + if idx == 0: + tgt_len += 1 + + src_len = min_length + idx + 1 + + expected_shape = ( + batch_size * num_beam_groups, + config.num_attention_heads, + tgt_len, + src_len, + ) + # check attn size + self.assertListEqual( + [layer_attention.shape for layer_attention in iter_attentions], + [expected_shape] * len(iter_attentions), + ) + @slow def test_model_from_pretrained(self): for model_name in XLNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: