diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 2b1c83dc99..d322bbb091 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -814,7 +814,7 @@ class RobertaForCausalLM(RobertaPreTrainedModel): logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - cross_attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, ) def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): diff --git a/tests/test_modeling_encoder_decoder.py b/tests/test_modeling_encoder_decoder.py index d446eea76b..42205dcf64 100644 --- a/tests/test_modeling_encoder_decoder.py +++ b/tests/test_modeling_encoder_decoder.py @@ -300,6 +300,9 @@ class EncoderDecoderMixin: labels, **kwargs ): + # make the decoder inputs a different shape from the encoder inputs to harden the test + decoder_input_ids = decoder_input_ids[:, :-1] + decoder_attention_mask = decoder_attention_mask[:, :-1] encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) enc_dec_model.to(torch_device) @@ -314,9 +317,8 @@ class EncoderDecoderMixin: encoder_attentions = outputs_encoder_decoder["encoder_attentions"] self.assertEqual(len(encoder_attentions), config.num_hidden_layers) - self.assertListEqual( - list(encoder_attentions[0].shape[-3:]), - [config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1]], + self.assertEqual( + encoder_attentions[0].shape[-3:], (config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1]) ) decoder_attentions = outputs_encoder_decoder["decoder_attentions"] @@ -327,20 +329,20 @@ class EncoderDecoderMixin: ) self.assertEqual(len(decoder_attentions), num_decoder_layers) - self.assertListEqual( - list(decoder_attentions[0].shape[-3:]), - [decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]], + self.assertEqual( + decoder_attentions[0].shape[-3:], + (decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]), ) cross_attentions = outputs_encoder_decoder["cross_attentions"] self.assertEqual(len(cross_attentions), num_decoder_layers) - cross_attention_input_seq_len = input_ids.shape[-1] * ( + cross_attention_input_seq_len = decoder_input_ids.shape[-1] * ( 1 + (decoder_config.ngram if hasattr(decoder_config, "ngram") else 0) ) - self.assertListEqual( - list(cross_attentions[0].shape[-3:]), - [decoder_config.num_attention_heads, cross_attention_input_seq_len, decoder_input_ids.shape[-1]], + self.assertEqual( + cross_attentions[0].shape[-3:], + (decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]), ) def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):