EncoderDecoderConfigs should not create new objects (#11300)

* removes the creation of separate config objects and uses the existing ones instead+overwrite resize_token_embeddings from parent class because it is not working for the EncoderDecoderModel

* rollback to current version of the huggingface master branch

* reworked version that ties the encoder and decoder config of the parent encoderdecoder instance

* overwrite of resize_token_embeddings throws an error now

* review comment suggestion

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* implemented warning in case encoderdecoder is created with differing configs of encoderdecoderconfig and decoderconfig or encoderconfig

* added test to avoid diverging configs of wrapper class and wrapped classes

* Update src/transformers/models/encoder_decoder/modeling_encoder_decoder.py

* make style

Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
cronoik
2021-04-25 11:45:46 +02:00
committed by GitHub
parent f45cb66bf6
commit 35cd8eed88
2 changed files with 57 additions and 0 deletions

View File

@@ -34,6 +34,7 @@ if is_torch_available():
import torch
from transformers import (
AutoConfig,
AutoTokenizer,
BartForCausalLM,
BertGenerationDecoder,
@@ -884,3 +885,38 @@ class BartEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
def test_encoder_decoder_model_shared_weights(self):
pass
@require_torch
class EncoderDecoderModelTest(unittest.TestCase):
def get_from_encoderdecoder_pretrained_model(self):
return EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
def get_decoder_config(self):
config = AutoConfig.from_pretrained("bert-base-uncased")
config.is_decoder = True
config.add_cross_attention = True
return config
def get_encoderdecoder_model(self):
return EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
def get_encoder_decoder_models(self):
encoder_model = BertModel.from_pretrained("bert-base-uncased")
decoder_model = BertLMHeadModel.from_pretrained("bert-base-uncased", config=self.get_decoder_config())
return {"encoder": encoder_model, "decoder": decoder_model}
def _check_configuration_tie(self, model):
assert id(model.decoder.config) == id(model.config.decoder)
assert id(model.encoder.config) == id(model.config.encoder)
@slow
def test_configuration_tie(self):
model = self.get_from_encoderdecoder_pretrained_model()
self._check_configuration_tie(model)
model = EncoderDecoderModel(**self.get_encoder_decoder_models())
self._check_configuration_tie(model)
model = self.get_encoderdecoder_model()
self._check_configuration_tie(model)