Fix bug in x-attentions output for roberta and harden test to catch it (#8660)
This commit is contained in:
@@ -300,6 +300,9 @@ class EncoderDecoderMixin:
|
||||
labels,
|
||||
**kwargs
|
||||
):
|
||||
# make the decoder inputs a different shape from the encoder inputs to harden the test
|
||||
decoder_input_ids = decoder_input_ids[:, :-1]
|
||||
decoder_attention_mask = decoder_attention_mask[:, :-1]
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
enc_dec_model.to(torch_device)
|
||||
@@ -314,9 +317,8 @@ class EncoderDecoderMixin:
|
||||
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
|
||||
self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
|
||||
|
||||
self.assertListEqual(
|
||||
list(encoder_attentions[0].shape[-3:]),
|
||||
[config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1]],
|
||||
self.assertEqual(
|
||||
encoder_attentions[0].shape[-3:], (config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1])
|
||||
)
|
||||
|
||||
decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
|
||||
@@ -327,20 +329,20 @@ class EncoderDecoderMixin:
|
||||
)
|
||||
self.assertEqual(len(decoder_attentions), num_decoder_layers)
|
||||
|
||||
self.assertListEqual(
|
||||
list(decoder_attentions[0].shape[-3:]),
|
||||
[decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]],
|
||||
self.assertEqual(
|
||||
decoder_attentions[0].shape[-3:],
|
||||
(decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
|
||||
)
|
||||
|
||||
cross_attentions = outputs_encoder_decoder["cross_attentions"]
|
||||
self.assertEqual(len(cross_attentions), num_decoder_layers)
|
||||
|
||||
cross_attention_input_seq_len = input_ids.shape[-1] * (
|
||||
cross_attention_input_seq_len = decoder_input_ids.shape[-1] * (
|
||||
1 + (decoder_config.ngram if hasattr(decoder_config, "ngram") else 0)
|
||||
)
|
||||
self.assertListEqual(
|
||||
list(cross_attentions[0].shape[-3:]),
|
||||
[decoder_config.num_attention_heads, cross_attention_input_seq_len, decoder_input_ids.shape[-1]],
|
||||
self.assertEqual(
|
||||
cross_attentions[0].shape[-3:],
|
||||
(decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]),
|
||||
)
|
||||
|
||||
def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user