diff --git a/src/transformers/modeling_encoder_decoder.py b/src/transformers/modeling_encoder_decoder.py index 57eccb6b33..c30e8abe1f 100644 --- a/src/transformers/modeling_encoder_decoder.py +++ b/src/transformers/modeling_encoder_decoder.py @@ -35,6 +35,7 @@ class EncoderDecoderModel(PreTrainedModel): class method for the encoder and `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` class method for the decoder. """ config_class = EncoderDecoderConfig + base_model_prefix = "encoder_decoder" def __init__( self, @@ -158,12 +159,26 @@ class EncoderDecoderModel(PreTrainedModel): ), "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined" from .modeling_auto import AutoModelWithLMHead + if "config" not in kwargs_decoder: + from transformers import AutoConfig + + decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path) + if decoder_config.is_decoder is False: + logger.info( + f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers." + ) + decoder_config.is_decoder = True + + kwargs_decoder["config"] = decoder_config + + if kwargs_decoder["config"].is_decoder is False: + logger.warning( + f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attribute `is_decoder` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` is set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`" + ) + decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) - decoder.config.is_decoder = True - model = cls(encoder=encoder, decoder=decoder) - - return model + return cls(encoder=encoder, decoder=decoder) def forward( self, diff --git a/tests/test_modeling_encoder_decoder.py b/tests/test_modeling_encoder_decoder.py index a07c69b242..6130cb8804 100644 --- a/tests/test_modeling_encoder_decoder.py +++ b/tests/test_modeling_encoder_decoder.py @@ -22,6 +22,7 @@ from transformers import is_torch_available # TODO(PVP): this line reruns all the tests in BertModelTest; not sure whether this can be prevented # for now only run module with pytest tests/test_modeling_encoder_decoder.py::EncoderDecoderModelTest from .test_modeling_bert import BertModelTester +from .test_modeling_common import ids_tensor from .utils import require_torch, slow, torch_device @@ -331,3 +332,33 @@ class EncoderDecoderModelTest(unittest.TestCase): def test_real_bert_model_from_pretrained(self): model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased") self.assertIsNotNone(model) + + @slow + def test_real_bert_model_from_pretrained_has_cross_attention(self): + model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased") + self.assertTrue(hasattr(model.decoder.bert.encoder.layer[0], "crossattention")) + + @slow + def test_real_bert_model_save_load_from_pretrained(self): + model_2 = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased") + model_2.to(torch_device) + input_ids = ids_tensor([13, 5], model_2.config.encoder.vocab_size) + decoder_input_ids = ids_tensor([13, 1], model_2.config.encoder.vocab_size) + attention_mask = ids_tensor([13, 5], vocab_size=2) + with torch.no_grad(): + outputs = model_2(input_ids=input_ids, decoder_input_ids=decoder_input_ids, attention_mask=attention_mask,) + out_2 = outputs[0].cpu().numpy() + out_2[np.isnan(out_2)] = 0 + + with tempfile.TemporaryDirectory() as tmp_dirname: + model_2.save_pretrained(tmp_dirname) + model_1 = EncoderDecoderModel.from_pretrained(tmp_dirname) + model_1.to(torch_device) + + after_outputs = model_1( + input_ids=input_ids, decoder_input_ids=decoder_input_ids, attention_mask=attention_mask, + ) + out_1 = after_outputs[0].cpu().numpy() + out_1[np.isnan(out_1)] = 0 + max_diff = np.amax(np.abs(out_1 - out_2)) + self.assertLessEqual(max_diff, 1e-5)