Fix TF T5/LED missing cross attn in retrun values (#15511)

* add cross attn to outputs

* add cross attn to outputs for TFLED

* add undo padding

* remove unused import

* fix style

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2022-02-07 17:41:48 +01:00
committed by GitHub
parent 6775b211b6
commit 131e258411
3 changed files with 39 additions and 11 deletions

View File

@@ -322,7 +322,7 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
self.assertEqual(len(global_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, seq_length],
[self.model_tester.num_attention_heads, seq_length, seq_length],
)
self.assertListEqual(
list(global_attentions[0].shape[-3:]),