Add flags to return scores, hidden states and / or attention weights in GenerationMixin (#9150)
* Define new output dataclasses for greedy generation * Add output_[...] flags in greedy generation methods Added output_attentions, output_hidden_states, output_scores flags in generate and greedy_search methods in GenerationMixin. * [WIP] Implement logic and tests for output flags in generation * Update GreedySearchOutput classes & docstring * Implement greedy search output accumulation logic Update greedy_search unittests Fix generate method return value docstring Properly init flags with the default config * Update configuration to add output_scores flag * Fix test_generation_utils Sort imports and fix isinstance tests for GreedySearchOutputs * Fix typo in generation_utils * Add return_dict_in_generate for backwards compatibility * Add return_dict_in_generate flag in config * Fix tyPo in configuration * Fix handling of attentions and hidden_states flags * Make style & quality * first attempt attentions * some corrections * improve tests * special models requires special test * disable xlm test for now * clean tests * fix for tf * isort * Add output dataclasses for other generation methods * Add logic to return dict in sample generation * Complete test for sample generation - Pass output_attentions and output_hidden_states flags to encoder in encoder-decoder models - Fix import satements order in test_generation_utils file * Add logic to return dict in sample generation - Refactor tests to avoid using self.assertTrue, which provides scarce information when the test fails - Add tests for the three beam_search methods: vanilla, sample and grouped * Style doc * Fix copy-paste error in generation tests * Rename logits to scores and refactor * Refactor group_beam_search for consistency * make style * add sequences_scores * fix all tests * add docs * fix beam search finalize test * correct docstring * clean some files * Made suggested changes to the documentation * Style doc ? * Style doc using the Python util * Update src/transformers/generation_utils.py * fix empty lines * fix all test Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -638,6 +638,69 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod
|
||||
model = ReformerModelWithLMHead.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
def _check_attentions_for_generate(
|
||||
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||
):
|
||||
self.assertIsInstance(attentions, tuple)
|
||||
self.assertListEqual(
|
||||
[isinstance(iter_attentions, list) for iter_attentions in attentions], [True] * len(attentions)
|
||||
)
|
||||
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
|
||||
|
||||
for idx, iter_attentions in enumerate(attentions):
|
||||
tgt_len = min_length + idx if not use_cache else 1
|
||||
num_chunks = tgt_len // config.local_attn_chunk_length + (tgt_len % config.local_attn_chunk_length != 0)
|
||||
tgt_chunk_len = config.local_attn_chunk_length
|
||||
src_chunk_len = config.local_attn_chunk_length * (
|
||||
1 + config.local_num_chunks_after + config.local_num_chunks_before
|
||||
)
|
||||
|
||||
if use_cache:
|
||||
expected_shape = (
|
||||
batch_size * num_beam_groups,
|
||||
config.num_attention_heads,
|
||||
tgt_len,
|
||||
min_length // config.local_attn_chunk_length + 1 + idx,
|
||||
)
|
||||
else:
|
||||
expected_shape = (
|
||||
batch_size * num_beam_groups,
|
||||
config.num_attention_heads,
|
||||
num_chunks,
|
||||
tgt_chunk_len,
|
||||
src_chunk_len,
|
||||
)
|
||||
# check attn size
|
||||
self.assertListEqual(
|
||||
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
|
||||
)
|
||||
|
||||
def _check_hidden_states_for_generate(
|
||||
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||
):
|
||||
self.assertIsInstance(hidden_states, tuple)
|
||||
self.assertListEqual(
|
||||
[isinstance(iter_hidden_states, list) for iter_hidden_states in hidden_states],
|
||||
[True] * len(hidden_states),
|
||||
)
|
||||
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
|
||||
|
||||
for idx, iter_hidden_states in enumerate(hidden_states):
|
||||
seq_len = min_length + idx
|
||||
seq_len = config.local_attn_chunk_length * (
|
||||
seq_len // config.local_attn_chunk_length + (seq_len % config.local_attn_chunk_length != 0)
|
||||
)
|
||||
|
||||
if use_cache:
|
||||
seq_len = 1
|
||||
|
||||
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
|
||||
# check hidden size
|
||||
self.assertListEqual(
|
||||
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
|
||||
[expected_shape] * len(iter_hidden_states),
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
@@ -696,13 +759,77 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, Generation
|
||||
self.model_tester = ReformerModelTester(self, **tester_kwargs)
|
||||
self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37)
|
||||
|
||||
def _check_attentions_for_generate(
|
||||
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||
):
|
||||
self.assertIsInstance(attentions, tuple)
|
||||
self.assertListEqual(
|
||||
[isinstance(iter_attentions, list) for iter_attentions in attentions], [True] * len(attentions)
|
||||
)
|
||||
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
|
||||
|
||||
for idx, iter_attentions in enumerate(attentions):
|
||||
tgt_len = min_length + idx if not use_cache else 1
|
||||
num_chunks = tgt_len // config.lsh_attn_chunk_length + (tgt_len % config.lsh_attn_chunk_length != 0)
|
||||
tgt_chunk_len = config.lsh_attn_chunk_length
|
||||
src_chunk_len = config.lsh_attn_chunk_length * (
|
||||
1 + config.lsh_num_chunks_after + config.lsh_num_chunks_before
|
||||
)
|
||||
|
||||
if use_cache:
|
||||
expected_shape = (
|
||||
batch_size * num_beam_groups,
|
||||
config.num_attention_heads,
|
||||
config.num_hashes,
|
||||
tgt_len,
|
||||
config.num_hashes * (1 + config.lsh_num_chunks_after + config.lsh_num_chunks_before),
|
||||
)
|
||||
else:
|
||||
expected_shape = (
|
||||
batch_size * num_beam_groups,
|
||||
config.num_attention_heads,
|
||||
num_chunks * config.num_hashes,
|
||||
tgt_chunk_len,
|
||||
src_chunk_len,
|
||||
)
|
||||
# check attn size
|
||||
self.assertListEqual(
|
||||
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
|
||||
)
|
||||
|
||||
def _check_hidden_states_for_generate(
|
||||
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||
):
|
||||
self.assertIsInstance(hidden_states, tuple)
|
||||
self.assertListEqual(
|
||||
[isinstance(iter_hidden_states, list) for iter_hidden_states in hidden_states],
|
||||
[True] * len(hidden_states),
|
||||
)
|
||||
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
|
||||
|
||||
for idx, iter_hidden_states in enumerate(hidden_states):
|
||||
seq_len = min_length + idx if not use_cache else 1
|
||||
seq_len = config.lsh_attn_chunk_length * (
|
||||
seq_len // config.lsh_attn_chunk_length + (seq_len % config.lsh_attn_chunk_length != 0)
|
||||
)
|
||||
|
||||
if use_cache:
|
||||
seq_len = 1
|
||||
|
||||
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
|
||||
# check hidden size
|
||||
self.assertListEqual(
|
||||
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
|
||||
[expected_shape] * len(iter_hidden_states),
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
class ReformerIntegrationTests(unittest.TestCase):
|
||||
"""
|
||||
These integration tests test the current layer activations and gradients againts the output of the Hugging Face Reformer model at time of integration: 29/06/2020. During integration, the model was tested against the output of the official Trax ReformerLM model for various cases ("lsh" only, "local" only, masked / non-masked, different chunk length, ....). In order to recover the original trax integration tests, one should use patrickvonplaten's fork of trax and the code that lives on the branch `reformer_trax_tests`.
|
||||
These integration tests test the current layer activations and gradients againts the output of the Hugging Face Reformer model at time of integration: 29/06/2020. During integration, the model was tested against the output of the official Trax ReformerLM model for various cases ("lsh" only, "lsh" only, masked / non-masked, different chunk length, ....). In order to recover the original trax integration tests, one should use patrickvonplaten's fork of trax and the code that lives on the branch `reformer_trax_tests`.
|
||||
"""
|
||||
|
||||
def _get_basic_config_and_input(self):
|
||||
|
||||
Reference in New Issue
Block a user