Add flags to return scores, hidden states and / or attention weights in GenerationMixin (#9150)
* Define new output dataclasses for greedy generation
* Add output_[...] flags in greedy generation methods
  Added output_attentions, output_hidden_states, output_scores flags in generate and greedy_search methods in GenerationMixin.
* [WIP] Implement logic and tests for output flags in generation
* Update GreedySearchOutput classes & docstring
* Implement greedy search output accumulation logic
  Update greedy_search unittests
  Fix generate method return value docstring
  Properly init flags with the default config
* Update configuration to add output_scores flag
* Fix test_generation_utils
  Sort imports and fix isinstance tests for GreedySearchOutputs
* Fix typo in generation_utils
* Add return_dict_in_generate for backwards compatibility
* Add return_dict_in_generate flag in config
* Fix tyPo in configuration
* Fix handling of attentions and hidden_states flags
* Make style & quality
* first attempt attentions
* some corrections
* improve tests
* special models requires special test
* disable xlm test for now
* clean tests
* fix for tf
* isort
* Add output dataclasses for other generation methods
* Add logic to return dict in sample generation
* Complete test for sample generation
  - Pass output_attentions and output_hidden_states flags to encoder in encoder-decoder models
  - Fix import satements order in test_generation_utils file
* Add logic to return dict in sample generation
  - Refactor tests to avoid using self.assertTrue, which provides scarce information when the test fails
  - Add tests for the three beam_search methods: vanilla, sample and grouped
* Style doc
* Fix copy-paste error in generation tests
* Rename logits to scores and refactor
* Refactor group_beam_search for consistency
* make style
* add sequences_scores
* fix all tests
* add docs
* fix beam search finalize test
* correct docstring
* clean some files
* Made suggested changes to the documentation
* Style doc ?
* Style doc using the Python util
* Update src/transformers/generation_utils.py
* fix empty lines
* fix all test

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
@@ -13,13 +13,102 @@
Utilities for Generation
-----------------------------------------------------------------------------------------------------------------------

This page lists all the utility functions used by :meth:`~transformers.PreTrainedModel.generate`,
:meth:`~transformers.PreTrainedModel.greedy_search`, :meth:`~transformers.PreTrainedModel.sample`,
:meth:`~transformers.PreTrainedModel.beam_search`, :meth:`~transformers.PreTrainedModel.beam_sample`, and
:meth:`~transformers.PreTrainedModel.group_beam_search`.

Most of those are only useful if you are studying the code of the generate methods in the library.

Generate Outputs
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The output of :meth:`~transformers.PreTrainedModel.generate` is an instance of a subclass of
:class:`~transformers.file_utils.ModelOutput`. This output is a data structure containing all the information returned
by :meth:`~transformers.PreTrainedModel.generate`, which can also be used as a tuple or a dictionary.

Here's an example:

.. code-block::

    from transformers import GPT2Tokenizer, GPT2LMHeadModel

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model = GPT2LMHeadModel.from_pretrained('gpt2')

    inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
    # return_dict_in_generate=True returns an output object; output_scores=True adds the per-step scores to it
    generation_output = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)

The ``generation_output`` object is a :class:`~transformers.generation_utils.GreedySearchDecoderOnlyOutput`. As we can
see in the documentation of that class below, it has the following attributes:

- ``sequences``: the generated sequences of tokens
- ``scores`` (optional): the prediction scores of the language modeling head, for each generation step
- ``hidden_states`` (optional): the hidden states of the model, for each generation step
- ``attentions`` (optional): the attention weights of the model, for each generation step

Here we have the ``scores`` since we passed along ``output_scores=True``, but we don't have ``hidden_states`` and
``attentions`` because we didn't pass ``output_hidden_states=True`` or ``output_attentions=True``.
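
To make "for each generation step" concrete, here is a minimal sketch of a greedy loop that collects one entry of scores per step, and only when asked to. It is a toy illustration, not the library's implementation: ``toy_greedy_search`` and ``toy_step`` are hypothetical names, and the "model" is a plain function over a 4-token vocabulary.

```python
from typing import Callable, List, Optional, Tuple

Scores = List[float]


def toy_greedy_search(
    step_fn: Callable[[List[int]], Scores],
    input_ids: List[int],
    max_new_tokens: int,
    output_scores: bool = False,
) -> Tuple[List[int], Optional[Tuple[Scores, ...]]]:
    """Greedy decoding over a hypothetical step function (illustration only)."""
    sequence = list(input_ids)
    # The accumulator only exists when the flag is set; otherwise None is reported.
    scores: Optional[Tuple[Scores, ...]] = () if output_scores else None
    for _ in range(max_new_tokens):
        step_scores = step_fn(sequence)
        if scores is not None:
            # One entry per generation step.
            scores += (step_scores,)
        # Greedy choice: index of the highest score.
        next_token = max(range(len(step_scores)), key=step_scores.__getitem__)
        sequence.append(next_token)
    return sequence, scores


def toy_step(sequence: List[int]) -> Scores:
    """Hypothetical 'model': favours (last token + 1) % 4."""
    favoured = (sequence[-1] + 1) % 4
    return [1.0 if token == favoured else 0.0 for token in range(4)]


sequence, scores = toy_greedy_search(toy_step, [0], max_new_tokens=3, output_scores=True)
print(sequence)     # [0, 1, 2, 3]
print(len(scores))  # 3, one entry per generated token
```

Calling the same function without ``output_scores=True`` returns ``None`` in place of the scores, which mirrors the optional attributes described above.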

You can access each attribute as you would usually do, and if that attribute has not been returned by the model, you
will get ``None``. Here for instance ``generation_output.scores`` contains all the generated prediction scores of the
language modeling head, and ``generation_output.attentions`` is ``None``.

When using our ``generation_output`` object as a tuple, it only keeps the attributes that don't have ``None`` values.
Here, for instance, it has two elements, ``sequences`` then ``scores``, so

.. code-block::

    generation_output[:2]

will return the tuple ``(generation_output.sequences, generation_output.scores)``.

When using our ``generation_output`` object as a dictionary, it only keeps the attributes that don't have ``None``
values. Here, for instance, it has two keys that are ``sequences`` and ``scores``.
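
The tuple and dictionary behaviour described above can be sketched with a small stand-in class. This is an assumption-laden illustration, not the library's ``ModelOutput``: ``ToyGenerationOutput`` and its ``to_tuple`` / ``to_dict`` helpers are hypothetical, and the field order simply follows the attribute list given earlier.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional, Tuple


@dataclass
class ToyGenerationOutput:
    """Hypothetical stand-in for a generation output (illustration only)."""

    sequences: Any = None
    scores: Optional[Any] = None
    attentions: Optional[Any] = None
    hidden_states: Optional[Any] = None

    def to_dict(self) -> Dict[str, Any]:
        # Keep fields in declaration order, skipping those left as None.
        items = ((f.name, getattr(self, f.name)) for f in fields(self))
        return {name: value for name, value in items if value is not None}

    def to_tuple(self) -> Tuple[Any, ...]:
        return tuple(self.to_dict().values())

    def __getitem__(self, key):
        # String keys behave like a dictionary, integers and slices like a tuple.
        if isinstance(key, str):
            return self.to_dict()[key]
        return self.to_tuple()[key]


output = ToyGenerationOutput(sequences=[[0, 1, 2]], scores=([0.1, 0.9],))
print(output[:2])              # ([[0, 1, 2]], ([0.1, 0.9],))
print(list(output.to_dict()))  # ['sequences', 'scores']
print(output.attentions)       # None
```

Attribute access still reports ``None`` for fields that were not requested, while the tuple and dictionary views leave those fields out entirely.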

We document here all output types.


GreedySearchOutput
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: transformers.generation_utils.GreedySearchDecoderOnlyOutput
    :members:

.. autoclass:: transformers.generation_utils.GreedySearchEncoderDecoderOutput
    :members:


SampleOutput
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: transformers.generation_utils.SampleDecoderOnlyOutput
    :members:

.. autoclass:: transformers.generation_utils.SampleEncoderDecoderOutput
    :members:


BeamSearchOutput
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: transformers.generation_utils.BeamSearchDecoderOnlyOutput
    :members:

.. autoclass:: transformers.generation_utils.BeamSearchEncoderDecoderOutput
    :members:


BeamSampleOutput
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: transformers.generation_utils.BeamSampleDecoderOnlyOutput
    :members:

.. autoclass:: transformers.generation_utils.BeamSampleEncoderDecoderOutput
    :members:


LogitsProcessor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~