[GenerationOutputs] Fix GenerationOutputs Tests (#9443)
* fix generation models * fix led * fix docs * add is_decoder * fix last docstrings * make style * fix t5 cross attentions * correct t5
This commit is contained in:
committed by
GitHub
parent
0c96262f7d
commit
b8462b5b2a
@@ -327,6 +327,32 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
# longformer cannot keep gradients in attentions or hidden states
|
||||
return
|
||||
|
||||
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
|
||||
# make sure tgt_length is padded
|
||||
tgt_length = (
|
||||
seq_length // config.attention_window[0] + (seq_length % config.attention_window[0] != 0)
|
||||
) * config.attention_window[0]
|
||||
|
||||
encoder_expected_shape = (batch_size, config.num_attention_heads, tgt_length, seq_length)
|
||||
self.assertIsInstance(attentions, tuple)
|
||||
self.assertListEqual(
|
||||
[layer_attentions.shape for layer_attentions in attentions],
|
||||
[encoder_expected_shape] * len(attentions),
|
||||
)
|
||||
|
||||
def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, seq_length):
|
||||
# make sure seq_length is padded
|
||||
seq_length = (
|
||||
seq_length // config.attention_window[0] + (seq_length % config.attention_window[0] != 0)
|
||||
) * config.attention_window[0]
|
||||
|
||||
encoder_expected_shape = (batch_size, seq_length, config.hidden_size)
|
||||
self.assertIsInstance(hidden_states, tuple)
|
||||
self.assertListEqual(
|
||||
[layer_hidden_states.shape for layer_hidden_states in hidden_states],
|
||||
[encoder_expected_shape] * len(hidden_states),
|
||||
)
|
||||
|
||||
def test_attention_outputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.return_dict = True
|
||||
|
||||
Reference in New Issue
Block a user