Fix incorrect output shapes for TF/PT LED (#13882)

* Fix issues with LED model

* Style pass

* Bugfixes

* correct attentions as well

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Matt
2021-10-07 17:30:15 +01:00
committed by GitHub
parent 5f34163b88
commit 61cf2ea9c0
3 changed files with 13 additions and 32 deletions

View File

@@ -126,9 +126,7 @@ class LEDModelTester:
# because of padding `encoder_seq_length`, is different from `seq_length`. Relevant for
# the `test_attention_outputs` and `test_hidden_states_output` tests
self.encoder_seq_length = (
self.seq_length + (self.attention_window - self.seq_length % self.attention_window) % self.attention_window
)
self.encoder_seq_length = self.seq_length
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -354,32 +352,6 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
# longformer cannot keep gradients in attentions or hidden states
return
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
# make sure tgt_length is padded
tgt_length = (
seq_length // config.attention_window[0] + (seq_length % config.attention_window[0] != 0)
) * config.attention_window[0]
encoder_expected_shape = (batch_size, config.num_attention_heads, tgt_length, seq_length)
self.assertIsInstance(attentions, tuple)
self.assertListEqual(
[layer_attentions.shape for layer_attentions in attentions],
[encoder_expected_shape] * len(attentions),
)
def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, seq_length):
# make sure seq_length is padded
seq_length = (
seq_length // config.attention_window[0] + (seq_length % config.attention_window[0] != 0)
) * config.attention_window[0]
encoder_expected_shape = (batch_size, seq_length, config.hidden_size)
self.assertIsInstance(hidden_states, tuple)
self.assertListEqual(
[layer_hidden_states.shape for layer_hidden_states in hidden_states],
[encoder_expected_shape] * len(hidden_states),
)
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True