Add flags to return scores, hidden states and / or attention weights in GenerationMixin (#9150)
* Define new output dataclasses for greedy generation * Add output_[...] flags in greedy generation methods Added output_attentions, output_hidden_states, output_scores flags in generate and greedy_search methods in GenerationMixin. * [WIP] Implement logic and tests for output flags in generation * Update GreedySearchOutput classes & docstring * Implement greedy search output accumulation logic Update greedy_search unittests Fix generate method return value docstring Properly init flags with the default config * Update configuration to add output_scores flag * Fix test_generation_utils Sort imports and fix isinstance tests for GreedySearchOutputs * Fix typo in generation_utils * Add return_dict_in_generate for backwards compatibility * Add return_dict_in_generate flag in config * Fix tyPo in configuration * Fix handling of attentions and hidden_states flags * Make style & quality * first attempt attentions * some corrections * improve tests * special models requires special test * disable xlm test for now * clean tests * fix for tf * isort * Add output dataclasses for other generation methods * Add logic to return dict in sample generation * Complete test for sample generation - Pass output_attentions and output_hidden_states flags to encoder in encoder-decoder models - Fix import satements order in test_generation_utils file * Add logic to return dict in sample generation - Refactor tests to avoid using self.assertTrue, which provides scarce information when the test fails - Add tests for the three beam_search methods: vanilla, sample and grouped * Style doc * Fix copy-paste error in generation tests * Rename logits to scores and refactor * Refactor group_beam_search for consistency * make style * add sequences_scores * fix all tests * add docs * fix beam search finalize test * correct docstring * clean some files * Made suggested changes to the documentation * Style doc ? * Style doc using the Python util * Update src/transformers/generation_utils.py * fix empty lines * fix all test Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -593,6 +593,60 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
|
||||
# xlnet cannot keep gradients in attentions or hidden states
|
||||
return
|
||||
|
||||
def _check_hidden_states_for_generate(
|
||||
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||
):
|
||||
self.assertIsInstance(hidden_states, tuple)
|
||||
self.assertListEqual(
|
||||
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
|
||||
[True] * len(hidden_states),
|
||||
)
|
||||
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
|
||||
|
||||
for idx, iter_hidden_states in enumerate(hidden_states):
|
||||
# check hidden size
|
||||
for i, layer_hidden_states in enumerate(iter_hidden_states):
|
||||
# every 2nd tensor is from extra stream
|
||||
if i % 2 != 0:
|
||||
seq_len = 1
|
||||
else:
|
||||
# for first item dummy PAD token is appended so need one more
|
||||
seq_len = (min_length + 1) if idx == 0 else min_length
|
||||
|
||||
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
|
||||
self.assertEqual(layer_hidden_states.shape, expected_shape)
|
||||
|
||||
def _check_attentions_for_generate(
|
||||
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||
):
|
||||
self.assertIsInstance(attentions, tuple)
|
||||
self.assertListEqual(
|
||||
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
|
||||
)
|
||||
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
|
||||
|
||||
for idx, attentions_item in enumerate(attentions):
|
||||
for iter_attentions in attentions_item:
|
||||
tgt_len = min_length
|
||||
|
||||
# for first item dummy PAD token is appended so need one more
|
||||
if idx == 0:
|
||||
tgt_len += 1
|
||||
|
||||
src_len = min_length + idx + 1
|
||||
|
||||
expected_shape = (
|
||||
batch_size * num_beam_groups,
|
||||
config.num_attention_heads,
|
||||
tgt_len,
|
||||
src_len,
|
||||
)
|
||||
# check attn size
|
||||
self.assertListEqual(
|
||||
[layer_attention.shape for layer_attention in iter_attentions],
|
||||
[expected_shape] * len(iter_attentions),
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in XLNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
|
||||
Reference in New Issue
Block a user