[All Seq2Seq model + CLM models that can be used with EncoderDecoder] Add cross-attention weights to outputs (#8071)

* Output cross-attention with decoder attention output

* Update src/transformers/modeling_bert.py

* add cross-attention for t5 and bart as well

* fix tests

* correct typo in docs

* add sylvains and sams comments

* correct typo

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Yossi Synett
2020-11-06 18:34:48 +00:00
committed by GitHub
parent 30f2507a07
commit bc0d26d1de
16 changed files with 653 additions and 85 deletions

View File

@@ -292,6 +292,62 @@ class EncoderDecoderMixin:
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
)
def check_encoder_decoder_model_output_attentions(
self,
config,
input_ids,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
labels,
**kwargs
):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device)
outputs_encoder_decoder = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
output_attentions=True,
return_dict=True,
)
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
self.assertListEqual(
list(encoder_attentions[0].shape[-3:]),
[config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1]],
)
decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
num_decoder_layers = (
decoder_config.num_decoder_layers
if hasattr(decoder_config, "num_decoder_layers")
else decoder_config.num_hidden_layers
)
self.assertEqual(len(decoder_attentions), num_decoder_layers)
self.assertListEqual(
list(decoder_attentions[0].shape[-3:]),
[decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]],
)
cross_attentions = outputs_encoder_decoder["cross_attentions"]
self.assertEqual(len(cross_attentions), num_decoder_layers)
cross_attention_input_seq_len = input_ids.shape[-1] * (
1 + (decoder_config.ngram if hasattr(decoder_config, "ngram") else 0)
)
self.assertListEqual(
list(cross_attentions[0].shape[-3:]),
[decoder_config.num_attention_heads, cross_attention_input_seq_len, decoder_input_ids.shape[-1]],
)
def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
@@ -413,6 +469,10 @@ class EncoderDecoderMixin:
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_labels(**input_ids_dict)
def test_encoder_decoder_model_output_attentions(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_output_attentions(**input_ids_dict)
def test_encoder_decoder_model_generate(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_generate(**input_ids_dict)