[All Seq2Seq model + CLM models that can be used with EncoderDecoder] Add cross-attention weights to outputs (#8071)
* Output cross-attention with decoder attention output * Update src/transformers/modeling_bert.py * add cross-attention for t5 and bart as well * fix tests * correct typo in docs * add sylvains and sams comments * correct typo Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -253,9 +253,7 @@ class ModelTesterMixin:
|
||||
out_len = len(outputs)
|
||||
|
||||
if self.is_encoder_decoder:
|
||||
correct_outlen = (
|
||||
self.model_tester.base_model_out_len if hasattr(self.model_tester, "base_model_out_len") else 4
|
||||
)
|
||||
correct_outlen = 5
|
||||
|
||||
# loss is at first position
|
||||
if "labels" in inputs_dict:
|
||||
@@ -266,6 +264,7 @@ class ModelTesterMixin:
|
||||
|
||||
self.assertEqual(out_len, correct_outlen)
|
||||
|
||||
# decoder attentions
|
||||
decoder_attentions = outputs.decoder_attentions
|
||||
self.assertIsInstance(decoder_attentions, (list, tuple))
|
||||
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
|
||||
@@ -274,6 +273,19 @@ class ModelTesterMixin:
|
||||
[self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
|
||||
)
|
||||
|
||||
# cross attentions
|
||||
cross_attentions = outputs.cross_attentions
|
||||
self.assertIsInstance(cross_attentions, (list, tuple))
|
||||
self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
|
||||
self.assertListEqual(
|
||||
list(cross_attentions[0].shape[-3:]),
|
||||
[
|
||||
self.model_tester.num_attention_heads,
|
||||
decoder_seq_length,
|
||||
encoder_key_length,
|
||||
],
|
||||
)
|
||||
|
||||
# Check attention is always last and order is fine
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
|
||||
Reference in New Issue
Block a user