Add flags to return scores, hidden states and / or attention weights in GenerationMixin (#9150)

* Define new output dataclasses for greedy generation

* Add output_[...] flags in greedy generation methods

Added output_attentions, output_hidden_states, output_scores flags in
generate and greedy_search methods in GenerationMixin.

* [WIP] Implement logic and tests for output flags in generation

* Update GreedySearchOutput classes & docstring

* Implement greedy search output accumulation logic

Update greedy_search unittests

Fix generate method return value docstring

Properly init flags with the default config

* Update configuration to add output_scores flag

* Fix test_generation_utils

Sort imports and fix isinstance tests for GreedySearchOutputs

* Fix typo in generation_utils

* Add return_dict_in_generate for backwards compatibility

* Add return_dict_in_generate flag in config

* Fix tyPo in configuration

* Fix handling of attentions and hidden_states flags

* Make style & quality

* first attempt attentions

* some corrections

* improve tests

* special models requires special test

* disable xlm test for now

* clean tests

* fix for tf

* isort

* Add output dataclasses for other generation methods

* Add logic to return dict in sample generation

* Complete test for sample generation

- Pass output_attentions and output_hidden_states flags to encoder in
encoder-decoder models
- Fix import satements order in test_generation_utils file

* Add logic to return dict in sample generation

- Refactor tests to avoid using self.assertTrue, which provides
scarce information when the test fails
- Add tests for the three beam_search methods: vanilla, sample and
grouped

* Style doc

* Fix copy-paste error in generation tests

* Rename logits to scores and refactor

* Refactor group_beam_search for consistency

* make style

* add sequences_scores

* fix all tests

* add docs

* fix beam search finalize test

* correct docstring

* clean some files

* Made suggested changes to the documentation

* Style doc ?

* Style doc using the Python util

* Update src/transformers/generation_utils.py

* fix empty lines

* fix all test

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Simon Brandeis
2021-01-06 17:11:42 +01:00
committed by GitHub
parent 7a9f1b5c99
commit c89f1bc92e
11 changed files with 2014 additions and 321 deletions

View File

@@ -13,13 +13,102 @@
Utilities for Generation
-----------------------------------------------------------------------------------------------------------------------
This page lists all the utility functions used by :meth:`~transformers.PretrainedModel.generate`,
:meth:`~transformers.PretrainedModel.greedy_search`, :meth:`~transformers.PretrainedModel.sample`,
:meth:`~transformers.PretrainedModel.beam_search`, :meth:`~transformers.PretrainedModel.beam_sample`, and
:meth:`~transformers.PretrainedModel.group_beam_search`.
This page lists all the utility functions used by :meth:`~transformers.PreTrainedModel.generate`,
:meth:`~transformers.PreTrainedModel.greedy_search`, :meth:`~transformers.PreTrainedModel.sample`,
:meth:`~transformers.PreTrainedModel.beam_search`, :meth:`~transformers.PreTrainedModel.beam_sample`, and
:meth:`~transformers.PreTrainedModel.group_beam_search`.
Most of those are only useful if you are studying the code of the generate methods in the library.
Generate Outputs
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The output of :meth:`~transformers.PreTrainedModel.generate` is an instance of a subclass of
:class:`~transformers.file_utils.ModelOutput`. This output is a data structure containing all the information returned
by :meth:`~transformers.PreTrainedModel.generate`, but that can also be used as tuple or dictionary.
Here's an example:
.. code-block::
from transformers import GPT2Tokenizer, GPT2LMHeadModel
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
generation_output = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
The ``generation_output`` object is a :class:`~transformers.generation_utils.GreedySearchDecoderOnlyOutput`, as we can
see in the documentation of that class below, it means it has the following attributes:
- ``sequences``: the generated sequences of tokens
- ``scores`` (optional): the prediction scores of the language modelling head, for each generation step
- ``hidden_states`` (optional): the hidden states of the model, for each generation step
- ``attentions`` (optional): the attention weights of the model, for each generation step
Here we have the ``scores`` since we passed along ``output_scores=True``, but we don't have ``hidden_states`` and
``attentions`` because we didn't pass ``output_hidden_states=True`` or ``output_attentions=True``.
You can access each attribute as you would usually do, and if that attribute has not been returned by the model, you
will get ``None``. Here for instance ``generation_output.scores`` are all the generated prediction scores of the
language modeling head, and ``generation_output.attentions`` is ``None``.
When using our ``generation_output`` object as a tuple, it only keeps the attributes that don't have ``None`` values.
Here, for instance, it has two elements, ``loss`` then ``logits``, so
.. code-block::
generation_output[:2]
will return the tuple ``(generation_output.sequences, generation_output.scores)`` for instance.
When using our ``generation_output`` object as a dictionary, it only keeps the attributes that don't have ``None``
values. Here, for instance, it has two keys that are ``sequences`` and ``scores``.
We document here all output types.
GreedySearchOutput
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: transformers.generation_utils.GreedySearchDecoderOnlyOutput
:members:
.. autoclass:: transformers.generation_utils.GreedySearchEncoderDecoderOutput
:members:
SampleOutput
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: transformers.generation_utils.SampleDecoderOnlyOutput
:members:
.. autoclass:: transformers.generation_utils.SampleEncoderDecoderOutput
:members:
BeamSearchOutput
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: transformers.generation_utils.BeamSearchDecoderOnlyOutput
:members:
.. autoclass:: transformers.generation_utils.BeamSearchEncoderDecoderOutput
:members:
BeamSampleOutput
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: transformers.generation_utils.BeamSampleDecoderOnlyOutput
:members:
.. autoclass:: transformers.generation_utils.BeamSampleEncoderDecoderOutput
:members:
LogitsProcessor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~