From 8cca875569dc74d01f2e207b47300ca728ce5164 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 5 Jun 2020 23:16:37 +0200 Subject: [PATCH] [EncoderDecoderConfig] automatically set decoder config to decoder (#4809) * automatically set decoder config to decoder * add more tests --- .../configuration_encoder_decoder.py | 3 ++ tests/test_modeling_encoder_decoder.py | 38 ++++++++++++++++++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/src/transformers/configuration_encoder_decoder.py b/src/transformers/configuration_encoder_decoder.py index 2fafbebb8d..2107deb4ca 100644 --- a/src/transformers/configuration_encoder_decoder.py +++ b/src/transformers/configuration_encoder_decoder.py @@ -85,6 +85,9 @@ class EncoderDecoderConfig(PretrainedConfig): Returns: :class:`EncoderDecoderConfig`: An instance of a configuration object """ + logger.info("Set `config.is_decoder=True` for decoder_config") + decoder_config.is_decoder = True + return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict()) def to_dict(self): diff --git a/tests/test_modeling_encoder_decoder.py b/tests/test_modeling_encoder_decoder.py index ce3b69d91e..caf2891aab 100644 --- a/tests/test_modeling_encoder_decoder.py +++ b/tests/test_modeling_encoder_decoder.py @@ -27,7 +27,7 @@ from .utils import require_torch, slow, torch_device if is_torch_available(): - from transformers import BertModel, BertForMaskedLM, EncoderDecoderModel + from transformers import BertModel, BertForMaskedLM, EncoderDecoderModel, EncoderDecoderConfig import numpy as np import torch @@ -74,6 +74,36 @@ class EncoderDecoderModelTest(unittest.TestCase): "labels": decoder_token_labels, } + def create_and_check_bert_encoder_decoder_model_from_pretrained_configs( + self, + config, + input_ids, + attention_mask, + encoder_hidden_states, + decoder_config, + decoder_input_ids, + decoder_attention_mask, + **kwargs + ): + encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config) + self.assertTrue(encoder_decoder_config.decoder.is_decoder) + + enc_dec_model = EncoderDecoderModel(encoder_decoder_config) + enc_dec_model.to(torch_device) + enc_dec_model.eval() + + self.assertTrue(enc_dec_model.config.is_encoder_decoder) + + outputs_encoder_decoder = enc_dec_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + + self.assertEqual(outputs_encoder_decoder[0].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))) + self.assertEqual(outputs_encoder_decoder[1].shape, (input_ids.shape + (config.hidden_size,))) + def create_and_check_bert_encoder_decoder_model( self, config, @@ -88,6 +118,8 @@ class EncoderDecoderModelTest(unittest.TestCase): encoder_model = BertModel(config) decoder_model = BertForMaskedLM(decoder_config) enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) + self.assertTrue(enc_dec_model.config.decoder.is_decoder) + self.assertTrue(enc_dec_model.config.is_encoder_decoder) enc_dec_model.to(torch_device) outputs_encoder_decoder = enc_dec_model( input_ids=input_ids, @@ -304,6 +336,10 @@ class EncoderDecoderModelTest(unittest.TestCase): input_ids_dict = self.prepare_config_and_inputs_bert() self.create_and_check_bert_encoder_decoder_model(**input_ids_dict) + def test_bert_encoder_decoder_model_from_pretrained_configs(self): + input_ids_dict = self.prepare_config_and_inputs_bert() + self.create_and_check_bert_encoder_decoder_model_from_pretrained_configs(**input_ids_dict) + def test_bert_encoder_decoder_model_from_pretrained(self): input_ids_dict = self.prepare_config_and_inputs_bert() self.create_and_check_bert_encoder_decoder_model_from_pretrained(**input_ids_dict)