Add flags to return scores, hidden states and / or attention weights in GenerationMixin (#9150)
* Define new output dataclasses for greedy generation * Add output_[...] flags in greedy generation methods Added output_attentions, output_hidden_states, output_scores flags in generate and greedy_search methods in GenerationMixin. * [WIP] Implement logic and tests for output flags in generation * Update GreedySearchOutput classes & docstring * Implement greedy search output accumulation logic Update greedy_search unittests Fix generate method return value docstring Properly init flags with the default config * Update configuration to add output_scores flag * Fix test_generation_utils Sort imports and fix isinstance tests for GreedySearchOutputs * Fix typo in generation_utils * Add return_dict_in_generate for backwards compatibility * Add return_dict_in_generate flag in config * Fix tyPo in configuration * Fix handling of attentions and hidden_states flags * Make style & quality * first attempt attentions * some corrections * improve tests * special models requires special test * disable xlm test for now * clean tests * fix for tf * isort * Add output dataclasses for other generation methods * Add logic to return dict in sample generation * Complete test for sample generation - Pass output_attentions and output_hidden_states flags to encoder in encoder-decoder models - Fix import satements order in test_generation_utils file * Add logic to return dict in sample generation - Refactor tests to avoid using self.assertTrue, which provides scarce information when the test fails - Add tests for the three beam_search methods: vanilla, sample and grouped * Style doc * Fix copy-paste error in generation tests * Rename logits to scores and refactor * Refactor group_beam_search for consistency * make style * add sequences_scores * fix all tests * add docs * fix beam search finalize test * correct docstring * clean some files * Made suggested changes to the documentation * Style doc ? * Style doc using the Python util * Update src/transformers/generation_utils.py * fix empty lines * fix all test Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -124,6 +124,11 @@ class PretrainedConfig(object):
|
||||
- **num_return_sequences** (:obj:`int`, `optional`, defaults to 1) -- Number of independently computed returned
|
||||
sequences for each element in the batch that will be used by default in the :obj:`generate` method of the
|
||||
model.
|
||||
- **output_scores** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether the model should return the
|
||||
logits when used for generation
|
||||
- **return_dict_in_generate** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether the model should
|
||||
return a :class:`~transformers.file_utils.ModelOutput` instead of a :obj:`torch.LongTensor`
|
||||
|
||||
|
||||
Parameters for fine-tuning tasks
|
||||
|
||||
@@ -203,6 +208,8 @@ class PretrainedConfig(object):
|
||||
self.bad_words_ids = kwargs.pop("bad_words_ids", None)
|
||||
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
|
||||
self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
|
||||
self.output_scores = kwargs.pop("output_scores", False)
|
||||
self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)
|
||||
|
||||
# Fine-tuning task arguments
|
||||
self.architectures = kwargs.pop("architectures", None)
|
||||
@@ -343,6 +350,7 @@ class PretrainedConfig(object):
|
||||
|
||||
Passing :obj:`use_auth_token=True` is required when you want to use a private model.
|
||||
|
||||
|
||||
Returns:
|
||||
:class:`PretrainedConfig`: The configuration object instantiated from this pretrained model.
|
||||
|
||||
@@ -372,6 +380,8 @@ class PretrainedConfig(object):
|
||||
From a ``pretrained_model_name_or_path``, resolve to a dictionary of parameters, to be used for instantiating a
|
||||
:class:`~transformers.PretrainedConfig` using ``from_dict``.
|
||||
|
||||
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
|
||||
The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
|
||||
|
||||
@@ -281,7 +281,7 @@ class BeamSearchScorer(BeamScorer):
|
||||
final_beam_indices: torch.LongTensor,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
) -> torch.LongTensor:
|
||||
) -> Tuple[torch.LongTensor]:
|
||||
batch_size = len(self._beam_hyps)
|
||||
|
||||
# finalize all open beam hypotheses and add to generated hypotheses
|
||||
@@ -300,14 +300,20 @@ class BeamSearchScorer(BeamScorer):
|
||||
# select the best hypotheses
|
||||
sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
|
||||
best = []
|
||||
best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)
|
||||
|
||||
# retrieve best hypotheses
|
||||
for i, beam_hyp in enumerate(self._beam_hyps):
|
||||
sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
|
||||
for j in range(self.num_beam_hyps_to_keep):
|
||||
best_hyp = sorted_hyps.pop()[1]
|
||||
best_hyp_tuple = sorted_hyps.pop()
|
||||
best_score = best_hyp_tuple[0]
|
||||
best_hyp = best_hyp_tuple[1]
|
||||
sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
|
||||
|
||||
# append to lists
|
||||
best.append(best_hyp)
|
||||
best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
|
||||
|
||||
# prepare for adding eos
|
||||
sent_max_len = min(sent_lengths.max().item() + 1, self.max_length)
|
||||
@@ -322,7 +328,12 @@ class BeamSearchScorer(BeamScorer):
|
||||
decoded[i, : sent_lengths[i]] = hypo
|
||||
if sent_lengths[i] < self.max_length:
|
||||
decoded[i, sent_lengths[i]] = eos_token_id
|
||||
return decoded
|
||||
return UserDict(
|
||||
{
|
||||
"sequences": decoded,
|
||||
"sequence_scores": best_scores,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class BeamHypotheses:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -136,7 +136,6 @@ class OpenAIGPTConfig(PretrainedConfig):
|
||||
summary_activation=None,
|
||||
summary_proj_to_labels=True,
|
||||
summary_first_dropout=0.1,
|
||||
use_cache=True,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
@@ -159,7 +158,6 @@ class OpenAIGPTConfig(PretrainedConfig):
|
||||
self.summary_activation = summary_activation
|
||||
self.summary_first_dropout = summary_first_dropout
|
||||
self.summary_proj_to_labels = summary_proj_to_labels
|
||||
self.use_cache = use_cache
|
||||
|
||||
@property
|
||||
def max_position_embeddings(self):
|
||||
|
||||
Reference in New Issue
Block a user