Add flags to return scores, hidden states and / or attention weights in GenerationMixin (#9150)

* Define new output dataclasses for greedy generation

* Add output_[...] flags in greedy generation methods

Added output_attentions, output_hidden_states, output_scores flags in
generate and greedy_search methods in GenerationMixin.

* [WIP] Implement logic and tests for output flags in generation

* Update GreedySearchOutput classes & docstring

* Implement greedy search output accumulation logic

Update greedy_search unittests

Fix generate method return value docstring

Properly init flags with the default config

* Update configuration to add output_scores flag

* Fix test_generation_utils

Sort imports and fix isinstance tests for GreedySearchOutputs

* Fix typo in generation_utils

* Add return_dict_in_generate for backwards compatibility

* Add return_dict_in_generate flag in config

* Fix tyPo in configuration

* Fix handling of attentions and hidden_states flags

* Make style & quality

* first attempt attentions

* some corrections

* improve tests

* special models requires special test

* disable xlm test for now

* clean tests

* fix for tf

* isort

* Add output dataclasses for other generation methods

* Add logic to return dict in sample generation

* Complete test for sample generation

- Pass output_attentions and output_hidden_states flags to encoder in
encoder-decoder models
- Fix import satements order in test_generation_utils file

* Add logic to return dict in sample generation

- Refactor tests to avoid using self.assertTrue, which provides
scarce information when the test fails
- Add tests for the three beam_search methods: vanilla, sample and
grouped

* Style doc

* Fix copy-paste error in generation tests

* Rename logits to scores and refactor

* Refactor group_beam_search for consistency

* make style

* add sequences_scores

* fix all tests

* add docs

* fix beam search finalize test

* correct docstring

* clean some files

* Made suggested changes to the documentation

* Style doc ?

* Style doc using the Python util

* Update src/transformers/generation_utils.py

* fix empty lines

* fix all test

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Simon Brandeis
2021-01-06 17:11:42 +01:00
committed by GitHub
parent 7a9f1b5c99
commit c89f1bc92e
11 changed files with 2014 additions and 321 deletions

View File

@@ -190,7 +190,7 @@ class BeamSearchTester:
input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1)
# finalize
decoded = beam_scorer.finalize(
sequence_output = beam_scorer.finalize(
input_ids,
output_scores,
output_tokens,
@@ -198,19 +198,27 @@ class BeamSearchTester:
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_token_id,
)
sequences = sequence_output["sequences"]
sequence_scores = sequence_output["sequence_scores"]
# since `num_beam_hyps_to_keep` = 1 => only return `batch_size` x `max_length`
self.parent.assertListEqual(list(decoded.shape), [self.batch_size, max_length])
self.parent.assertListEqual(list(sequences.shape), [self.batch_size, max_length])
self.parent.assertListEqual(list(sequence_scores.shape), [self.batch_size])
# check sequence_scores
self.parent.assertFalse((sequence_scores > 0).any().item())
# first batch has to finish with eos_token
self.parent.assertEqual(decoded[0, -1].item(), self.eos_token_id)
self.parent.assertEqual(sequences[0, -1].item(), self.eos_token_id)
# other batches cannot finish with eos token
self.parent.assertNotEqual(decoded[1, -1].item(), self.eos_token_id)
self.parent.assertNotEqual(decoded[2, -1].item(), self.eos_token_id)
self.parent.assertNotEqual(sequences[1, -1].item(), self.eos_token_id)
self.parent.assertNotEqual(sequences[2, -1].item(), self.eos_token_id)
# now test that if `num_beam_hyps_to_keep` is 3 => all beams are returned
beam_scorer.num_beam_hyps_to_keep = self.num_beams
decoded = beam_scorer.finalize(
sequence_output = beam_scorer.finalize(
input_ids,
output_scores,
output_tokens,
@@ -218,7 +226,11 @@ class BeamSearchTester:
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_token_id,
)
self.parent.assertListEqual(list(decoded.shape), [self.num_beams * self.batch_size, max_length])
sequences = sequence_output["sequences"]
sequence_scores = sequence_output["sequence_scores"]
self.parent.assertListEqual(list(sequences.shape), [self.num_beams * self.batch_size, max_length])
self.parent.assertListEqual(list(sequence_scores.shape), [self.num_beams * self.batch_size])
@require_torch

File diff suppressed because it is too large Load Diff

View File

@@ -638,6 +638,69 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod
model = ReformerModelWithLMHead.from_pretrained(model_name)
self.assertIsNotNone(model)
def _check_attentions_for_generate(
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
):
self.assertIsInstance(attentions, tuple)
self.assertListEqual(
[isinstance(iter_attentions, list) for iter_attentions in attentions], [True] * len(attentions)
)
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
for idx, iter_attentions in enumerate(attentions):
tgt_len = min_length + idx if not use_cache else 1
num_chunks = tgt_len // config.local_attn_chunk_length + (tgt_len % config.local_attn_chunk_length != 0)
tgt_chunk_len = config.local_attn_chunk_length
src_chunk_len = config.local_attn_chunk_length * (
1 + config.local_num_chunks_after + config.local_num_chunks_before
)
if use_cache:
expected_shape = (
batch_size * num_beam_groups,
config.num_attention_heads,
tgt_len,
min_length // config.local_attn_chunk_length + 1 + idx,
)
else:
expected_shape = (
batch_size * num_beam_groups,
config.num_attention_heads,
num_chunks,
tgt_chunk_len,
src_chunk_len,
)
# check attn size
self.assertListEqual(
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
)
def _check_hidden_states_for_generate(
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
):
self.assertIsInstance(hidden_states, tuple)
self.assertListEqual(
[isinstance(iter_hidden_states, list) for iter_hidden_states in hidden_states],
[True] * len(hidden_states),
)
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
for idx, iter_hidden_states in enumerate(hidden_states):
seq_len = min_length + idx
seq_len = config.local_attn_chunk_length * (
seq_len // config.local_attn_chunk_length + (seq_len % config.local_attn_chunk_length != 0)
)
if use_cache:
seq_len = 1
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
# check hidden size
self.assertListEqual(
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
[expected_shape] * len(iter_hidden_states),
)
@require_torch
class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
@@ -696,13 +759,77 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, Generation
self.model_tester = ReformerModelTester(self, **tester_kwargs)
self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37)
def _check_attentions_for_generate(
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
):
self.assertIsInstance(attentions, tuple)
self.assertListEqual(
[isinstance(iter_attentions, list) for iter_attentions in attentions], [True] * len(attentions)
)
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
for idx, iter_attentions in enumerate(attentions):
tgt_len = min_length + idx if not use_cache else 1
num_chunks = tgt_len // config.lsh_attn_chunk_length + (tgt_len % config.lsh_attn_chunk_length != 0)
tgt_chunk_len = config.lsh_attn_chunk_length
src_chunk_len = config.lsh_attn_chunk_length * (
1 + config.lsh_num_chunks_after + config.lsh_num_chunks_before
)
if use_cache:
expected_shape = (
batch_size * num_beam_groups,
config.num_attention_heads,
config.num_hashes,
tgt_len,
config.num_hashes * (1 + config.lsh_num_chunks_after + config.lsh_num_chunks_before),
)
else:
expected_shape = (
batch_size * num_beam_groups,
config.num_attention_heads,
num_chunks * config.num_hashes,
tgt_chunk_len,
src_chunk_len,
)
# check attn size
self.assertListEqual(
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
)
def _check_hidden_states_for_generate(
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
):
self.assertIsInstance(hidden_states, tuple)
self.assertListEqual(
[isinstance(iter_hidden_states, list) for iter_hidden_states in hidden_states],
[True] * len(hidden_states),
)
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
for idx, iter_hidden_states in enumerate(hidden_states):
seq_len = min_length + idx if not use_cache else 1
seq_len = config.lsh_attn_chunk_length * (
seq_len // config.lsh_attn_chunk_length + (seq_len % config.lsh_attn_chunk_length != 0)
)
if use_cache:
seq_len = 1
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
# check hidden size
self.assertListEqual(
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
[expected_shape] * len(iter_hidden_states),
)
@require_torch
@require_sentencepiece
@require_tokenizers
class ReformerIntegrationTests(unittest.TestCase):
"""
These integration tests test the current layer activations and gradients againts the output of the Hugging Face Reformer model at time of integration: 29/06/2020. During integration, the model was tested against the output of the official Trax ReformerLM model for various cases ("lsh" only, "local" only, masked / non-masked, different chunk length, ....). In order to recover the original trax integration tests, one should use patrickvonplaten's fork of trax and the code that lives on the branch `reformer_trax_tests`.
These integration tests test the current layer activations and gradients againts the output of the Hugging Face Reformer model at time of integration: 29/06/2020. During integration, the model was tested against the output of the official Trax ReformerLM model for various cases ("lsh" only, "lsh" only, masked / non-masked, different chunk length, ....). In order to recover the original trax integration tests, one should use patrickvonplaten's fork of trax and the code that lives on the branch `reformer_trax_tests`.
"""
def _get_basic_config_and_input(self):

View File

@@ -304,6 +304,50 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC
# transfo-xl requires special resize for lm-head
return
def _check_attentions_for_generate(
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
):
self.assertIsInstance(attentions, tuple)
self.assertListEqual(
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
)
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
for idx, iter_attentions in enumerate(attentions):
tgt_len = min_length if idx == 0 else (min_length - 2)
src_len = (min_length + config.mem_len) if idx == 0 else (min_length + config.mem_len - 2)
expected_shape = (
batch_size * num_beam_groups,
config.num_attention_heads,
tgt_len,
src_len,
)
# check attn size
self.assertListEqual(
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
)
def _check_hidden_states_for_generate(
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
):
self.assertIsInstance(hidden_states, tuple)
self.assertListEqual(
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
[True] * len(hidden_states),
)
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
for idx, iter_hidden_states in enumerate(hidden_states):
seq_len = min_length if idx == 0 else min_length - 2
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
# check hidden size
self.assertListEqual(
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
[expected_shape] * len(iter_hidden_states),
)
@require_torch
class TransfoXLModelLanguageGenerationTest(unittest.TestCase):

View File

@@ -400,6 +400,52 @@ class XLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlm_for_multiple_choice(*config_and_inputs)
def _check_attentions_for_generate(
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
):
self.assertIsInstance(attentions, tuple)
self.assertListEqual(
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
)
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
for idx, iter_attentions in enumerate(attentions):
# adds PAD dummy token
tgt_len = min_length + idx + 1
src_len = min_length + idx + 1
expected_shape = (
batch_size * num_beam_groups,
config.num_attention_heads,
tgt_len,
src_len,
)
# check attn size
self.assertListEqual(
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
)
def _check_hidden_states_for_generate(
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
):
self.assertIsInstance(hidden_states, tuple)
self.assertListEqual(
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
[True] * len(hidden_states),
)
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
for idx, iter_hidden_states in enumerate(hidden_states):
# adds PAD dummy token
seq_len = min_length + idx + 1
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
# check hidden size
self.assertListEqual(
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
[expected_shape] * len(iter_hidden_states),
)
pass
@slow
def test_model_from_pretrained(self):
for model_name in XLM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:

View File

@@ -593,6 +593,60 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
# xlnet cannot keep gradients in attentions or hidden states
return
def _check_hidden_states_for_generate(
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
):
self.assertIsInstance(hidden_states, tuple)
self.assertListEqual(
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
[True] * len(hidden_states),
)
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
for idx, iter_hidden_states in enumerate(hidden_states):
# check hidden size
for i, layer_hidden_states in enumerate(iter_hidden_states):
# every 2nd tensor is from extra stream
if i % 2 != 0:
seq_len = 1
else:
# for first item dummy PAD token is appended so need one more
seq_len = (min_length + 1) if idx == 0 else min_length
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
self.assertEqual(layer_hidden_states.shape, expected_shape)
def _check_attentions_for_generate(
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
):
self.assertIsInstance(attentions, tuple)
self.assertListEqual(
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
)
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
for idx, attentions_item in enumerate(attentions):
for iter_attentions in attentions_item:
tgt_len = min_length
# for first item dummy PAD token is appended so need one more
if idx == 0:
tgt_len += 1
src_len = min_length + idx + 1
expected_shape = (
batch_size * num_beam_groups,
config.num_attention_heads,
tgt_len,
src_len,
)
# check attn size
self.assertListEqual(
[layer_attention.shape for layer_attention in iter_attentions],
[expected_shape] * len(iter_attentions),
)
@slow
def test_model_from_pretrained(self):
for model_name in XLNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: