[GenerationOutputs] Fix GenerationOutputs Tests (#9443)

* fix generation models

* fix led

* fix docs

* add is_decoder

* fix last docstrings

* make style

* fix t5 cross attentions

* correct t5
This commit is contained in:
Patrick von Platen
2021-01-06 19:37:02 +01:00
committed by GitHub
parent 0c96262f7d
commit b8462b5b2a
9 changed files with 89 additions and 53 deletions

View File

@@ -522,6 +522,7 @@ class GenerationTesterMixin:
return
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
output_greedy, output_generate = self._greedy_generate(
model=model,
@@ -730,6 +731,7 @@ class GenerationTesterMixin:
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
output_beam, output_generate = self._beam_search_generate(
model=model,
@@ -962,12 +964,7 @@ class GenerationTesterMixin:
# Attentions
if config.is_encoder_decoder:
# encoder
encoder_expected_shape = (batch_size, config.num_attention_heads, seq_length, seq_length)
self.assertIsInstance(output.encoder_attentions, tuple)
self.assertListEqual(
[layer_attentions.shape for layer_attentions in output.encoder_attentions],
[encoder_expected_shape] * len(output.encoder_attentions),
)
self._check_encoder_attention_for_generate(output.encoder_attentions, batch_size, config, seq_length)
# decoder
self._check_attentions_for_generate(
num_sequences_in_output,
@@ -993,11 +990,8 @@ class GenerationTesterMixin:
# Hidden States
if config.is_encoder_decoder:
# encoder
encoder_expected_shape = (batch_size, seq_length, config.hidden_size)
self.assertIsInstance(output.encoder_hidden_states, tuple)
self.assertListEqual(
[layer_hidden_states.shape for layer_hidden_states in output.encoder_hidden_states],
[encoder_expected_shape] * len(output.encoder_hidden_states),
self._check_encoder_hidden_states_for_generate(
output.encoder_hidden_states, batch_size, config, seq_length
)
# decoder
@@ -1052,6 +1046,14 @@ class GenerationTesterMixin:
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
)
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
    """Assert that the encoder attentions returned by generation are well-formed.

    Args:
        attentions: expected to be a tuple of per-layer objects exposing a
            ``shape`` attribute.
        batch_size: expected size of the first dimension.
        config: object providing ``num_attention_heads``.
        seq_length: expected size of both trailing dimensions.

    Fails the test if ``attentions`` is not a tuple or if any element's shape
    differs from ``(batch_size, num_attention_heads, seq_length, seq_length)``.
    """
    shape_per_layer = (batch_size, config.num_attention_heads, seq_length, seq_length)
    self.assertIsInstance(attentions, tuple)

    # Every layer must report the same square attention shape.
    observed_shapes = [attn.shape for attn in attentions]
    wanted_shapes = [shape_per_layer for _ in attentions]
    self.assertListEqual(observed_shapes, wanted_shapes)
def _check_hidden_states_for_generate(
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
):
@@ -1071,6 +1073,14 @@ class GenerationTesterMixin:
[expected_shape] * len(iter_hidden_states),
)
def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, seq_length):
    """Assert that the encoder hidden states returned by generation are well-formed.

    Args:
        hidden_states: expected to be a tuple of per-layer objects exposing a
            ``shape`` attribute.
        batch_size: expected size of the first dimension.
        config: object providing ``hidden_size``.
        seq_length: expected size of the second dimension.

    Fails the test if ``hidden_states`` is not a tuple or if any element's
    shape differs from ``(batch_size, seq_length, hidden_size)``.
    """
    shape_per_layer = (batch_size, seq_length, config.hidden_size)
    self.assertIsInstance(hidden_states, tuple)

    # Every layer must report the same hidden-state shape.
    observed_shapes = [state.shape for state in hidden_states]
    wanted_shapes = [shape_per_layer for _ in hidden_states]
    self.assertListEqual(observed_shapes, wanted_shapes)
@require_torch
class UtilsFunctionsTest(unittest.TestCase):

View File

@@ -327,6 +327,32 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
# longformer cannot keep gradients in attentions or hidden states
return
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
    """Assert the shape of encoder attentions, with a padded target length.

    The third dimension is expected to equal ``seq_length`` rounded up to the
    next multiple of ``config.attention_window[0]``; the last dimension is
    expected to equal the unpadded ``seq_length``.

    Args:
        attentions: expected to be a tuple of per-layer objects exposing a
            ``shape`` attribute.
        batch_size: expected size of the first dimension.
        config: object providing ``num_attention_heads`` and ``attention_window``.
        seq_length: unpadded input sequence length.
    """
    # make sure tgt_length is padded: round seq_length up to a whole number of windows
    window = config.attention_window[0]
    full_windows, leftover = divmod(seq_length, window)
    padded_tgt_length = (full_windows + (leftover != 0)) * window

    shape_per_layer = (batch_size, config.num_attention_heads, padded_tgt_length, seq_length)
    self.assertIsInstance(attentions, tuple)

    observed_shapes = [attn.shape for attn in attentions]
    wanted_shapes = [shape_per_layer for _ in attentions]
    self.assertListEqual(observed_shapes, wanted_shapes)
def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, seq_length):
    """Assert the shape of encoder hidden states, with a padded sequence length.

    The second dimension is expected to equal ``seq_length`` rounded up to the
    next multiple of ``config.attention_window[0]``.

    Args:
        hidden_states: expected to be a tuple of per-layer objects exposing a
            ``shape`` attribute.
        batch_size: expected size of the first dimension.
        config: object providing ``hidden_size`` and ``attention_window``.
        seq_length: unpadded input sequence length.
    """
    # make sure seq_length is padded: round it up to a whole number of windows
    window = config.attention_window[0]
    full_windows, leftover = divmod(seq_length, window)
    padded_length = (full_windows + (leftover != 0)) * window

    shape_per_layer = (batch_size, padded_length, config.hidden_size)
    self.assertIsInstance(hidden_states, tuple)

    observed_shapes = [state.shape for state in hidden_states]
    wanted_shapes = [shape_per_layer for _ in hidden_states]
    self.assertListEqual(observed_shapes, wanted_shapes)
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True