[All Seq2Seq model + CLM models that can be used with EncoderDecoder] Add cross-attention weights to outputs (#8071)

* Output cross-attention with decoder attention output

* Update src/transformers/modeling_bert.py

* add cross-attention for t5 and bart as well

* fix tests

* correct typo in docs

* add sylvains and sams comments

* correct typo

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Yossi Synett
2020-11-06 18:34:48 +00:00
committed by GitHub
parent 30f2507a07
commit bc0d26d1de
16 changed files with 653 additions and 85 deletions

View File

@@ -253,9 +253,7 @@ class ModelTesterMixin:
out_len = len(outputs)
if self.is_encoder_decoder:
correct_outlen = (
self.model_tester.base_model_out_len if hasattr(self.model_tester, "base_model_out_len") else 4
)
correct_outlen = 5
# loss is at first position
if "labels" in inputs_dict:
@@ -266,6 +264,7 @@ class ModelTesterMixin:
self.assertEqual(out_len, correct_outlen)
# decoder attentions
decoder_attentions = outputs.decoder_attentions
self.assertIsInstance(decoder_attentions, (list, tuple))
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
@@ -274,6 +273,19 @@ class ModelTesterMixin:
[self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
)
# cross attentions
cross_attentions = outputs.cross_attentions
self.assertIsInstance(cross_attentions, (list, tuple))
self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(cross_attentions[0].shape[-3:]),
[
self.model_tester.num_attention_heads,
decoder_seq_length,
encoder_key_length,
],
)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True