[GenerationOutputs] Fix GenerationOutputs Tests (#9443)
* fix generation models * fix led * fix docs * add is_decoder * fix last docstrings * make style * fix t5 cross attentions * correct t5
This commit is contained in:
committed by
GitHub
parent
0c96262f7d
commit
b8462b5b2a
@@ -522,6 +522,7 @@ class GenerationTesterMixin:
|
||||
return
|
||||
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_greedy, output_generate = self._greedy_generate(
|
||||
model=model,
|
||||
@@ -730,6 +731,7 @@ class GenerationTesterMixin:
|
||||
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
|
||||
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_beam, output_generate = self._beam_search_generate(
|
||||
model=model,
|
||||
@@ -962,12 +964,7 @@ class GenerationTesterMixin:
|
||||
# Attentions
|
||||
if config.is_encoder_decoder:
|
||||
# encoder
|
||||
encoder_expected_shape = (batch_size, config.num_attention_heads, seq_length, seq_length)
|
||||
self.assertIsInstance(output.encoder_attentions, tuple)
|
||||
self.assertListEqual(
|
||||
[layer_attentions.shape for layer_attentions in output.encoder_attentions],
|
||||
[encoder_expected_shape] * len(output.encoder_attentions),
|
||||
)
|
||||
self._check_encoder_attention_for_generate(output.encoder_attentions, batch_size, config, seq_length)
|
||||
# decoder
|
||||
self._check_attentions_for_generate(
|
||||
num_sequences_in_output,
|
||||
@@ -993,11 +990,8 @@ class GenerationTesterMixin:
|
||||
# Hidden States
|
||||
if config.is_encoder_decoder:
|
||||
# encoder
|
||||
encoder_expected_shape = (batch_size, seq_length, config.hidden_size)
|
||||
self.assertIsInstance(output.encoder_hidden_states, tuple)
|
||||
self.assertListEqual(
|
||||
[layer_hidden_states.shape for layer_hidden_states in output.encoder_hidden_states],
|
||||
[encoder_expected_shape] * len(output.encoder_hidden_states),
|
||||
self._check_encoder_hidden_states_for_generate(
|
||||
output.encoder_hidden_states, batch_size, config, seq_length
|
||||
)
|
||||
|
||||
# decoder
|
||||
@@ -1052,6 +1046,14 @@ class GenerationTesterMixin:
|
||||
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
|
||||
)
|
||||
|
||||
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
# Assertion helper: verifies that `attentions` is a tuple whose entries all report the
# shape (batch_size, config.num_attention_heads, seq_length, seq_length).
# NOTE(review): the bare `|` / `||||` rows between statements look like diff-viewer
# residue rather than Python; they are kept untouched here.
|
||||
# Shape that every entry of `attentions` is expected to have.
encoder_expected_shape = (batch_size, config.num_attention_heads, seq_length, seq_length)
|
||||
# The container itself must be a tuple (not a list or a single array).
self.assertIsInstance(attentions, tuple)
|
||||
# Compare the `.shape` of each entry against the expected shape, entry by entry.
self.assertListEqual(
|
||||
[layer_attentions.shape for layer_attentions in attentions],
|
||||
# One expected shape per entry, so an empty tuple passes trivially.
[encoder_expected_shape] * len(attentions),
|
||||
)
|
||||
|
||||
def _check_hidden_states_for_generate(
|
||||
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||
):
|
||||
@@ -1071,6 +1073,14 @@ class GenerationTesterMixin:
|
||||
[expected_shape] * len(iter_hidden_states),
|
||||
)
|
||||
|
||||
def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, seq_length):
# Assertion helper: verifies that `hidden_states` is a tuple whose entries all report the
# shape (batch_size, seq_length, config.hidden_size).
# NOTE(review): the bare `|` / `||||` rows between statements look like diff-viewer
# residue rather than Python; they are kept untouched here.
|
||||
# Shape that every entry of `hidden_states` is expected to have.
encoder_expected_shape = (batch_size, seq_length, config.hidden_size)
|
||||
# The container itself must be a tuple (not a list or a single array).
self.assertIsInstance(hidden_states, tuple)
|
||||
# Compare the `.shape` of each entry against the expected shape, entry by entry.
self.assertListEqual(
|
||||
[layer_hidden_states.shape for layer_hidden_states in hidden_states],
|
||||
# One expected shape per entry, so an empty tuple passes trivially.
[encoder_expected_shape] * len(hidden_states),
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
class UtilsFunctionsTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user