Fix incorrect output shapes for TF/PT LED (#13882)
* Fix issues with LED model
* Style pass
* Bugfixes
* Correct attentions as well

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -126,9 +126,7 @@ class LEDModelTester:
|
||||
|
||||
# because of padding, `encoder_seq_length` is different from `seq_length`. Relevant for
|
||||
# the `test_attention_outputs` and `test_hidden_states_output` tests
|
||||
self.encoder_seq_length = (
|
||||
self.seq_length + (self.attention_window - self.seq_length % self.attention_window) % self.attention_window
|
||||
)
|
||||
self.encoder_seq_length = self.seq_length
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
@@ -354,32 +352,6 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
# longformer cannot keep gradients in attentions or hidden states
|
||||
return
|
||||
|
||||
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
    """Check the shape of every encoder attention tensor returned during generation.

    The target length is ``seq_length`` rounded up to the next multiple of the
    attention window, so each layer is expected to have the shape
    ``(batch_size, config.num_attention_heads, padded_length, seq_length)``.

    Args:
        attentions: Tuple with one entry per layer; each entry exposes ``.shape``.
        batch_size: Expected size of the leading dimension.
        config: Object providing ``attention_window`` and ``num_attention_heads``.
            ``attention_window`` may be a single integer or a per-layer sequence,
            in which case its first entry is used.
        seq_length: Unpadded input sequence length.

    Raises:
        AssertionError: If ``attentions`` is not a tuple or any shape differs.
    """
    # Accept both a per-layer sequence (first entry wins) and a plain integer window.
    window = config.attention_window
    if not isinstance(window, int):
        window = window[0]

    # make sure tgt_length is padded: ceiling division, then back to a multiple of the window
    tgt_length = -(-seq_length // window) * window

    encoder_expected_shape = (batch_size, config.num_attention_heads, tgt_length, seq_length)
    self.assertIsInstance(attentions, tuple)
    self.assertListEqual(
        [layer_attentions.shape for layer_attentions in attentions],
        [encoder_expected_shape] * len(attentions),
    )
|
||||
|
||||
def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, seq_length):
    """Check the shape of every encoder hidden state returned during generation.

    The sequence length is rounded up to the next multiple of the attention
    window, so each entry is expected to have the shape
    ``(batch_size, padded_length, config.hidden_size)``.

    Args:
        hidden_states: Tuple of per-layer outputs; each entry exposes ``.shape``.
        batch_size: Expected size of the leading dimension.
        config: Object providing ``attention_window`` and ``hidden_size``.
            ``attention_window`` may be a single integer or a per-layer sequence,
            in which case its first entry is used.
        seq_length: Unpadded input sequence length.

    Raises:
        AssertionError: If ``hidden_states`` is not a tuple or any shape differs.
    """
    # Accept both a per-layer sequence (first entry wins) and a plain integer window.
    window = config.attention_window
    if not isinstance(window, int):
        window = window[0]

    # make sure seq_length is padded; a separate local keeps the parameter untouched
    padded_length = -(-seq_length // window) * window

    encoder_expected_shape = (batch_size, padded_length, config.hidden_size)
    self.assertIsInstance(hidden_states, tuple)
    self.assertListEqual(
        [layer_hidden_states.shape for layer_hidden_states in hidden_states],
        [encoder_expected_shape] * len(hidden_states),
    )
|
||||
|
||||
def test_attention_outputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.return_dict = True
|
||||
|
||||
Reference in New Issue
Block a user