EncoderDecoderConfigs should not create new objects (#11300)
* removes the creation of separate config objects and uses the existing ones instead+overwrite resize_token_embeddings from parent class because it is not working for the EncoderDecoderModel * rollback to current version of the huggingface master branch * reworked version that ties the encoder and decoder config of the parent encoderdecoder instance * overwrite of resize_token_embeddings throws an error now * review comment suggestion Co-authored-by: Suraj Patil <surajp815@gmail.com> * implemented warning in case encoderdecoder is created with differing configs of encoderdecoderconfig and decoderconfig or encoderconfig * added test to avoid diverging configs of wrapper class and wrapped classes * Update src/transformers/models/encoder_decoder/modeling_encoder_decoder.py * make style Co-authored-by: Suraj Patil <surajp815@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -34,6 +34,7 @@ if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
BartForCausalLM,
|
||||
BertGenerationDecoder,
|
||||
@@ -884,3 +885,38 @@ class BartEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
|
||||
def test_encoder_decoder_model_shared_weights(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class EncoderDecoderModelTest(unittest.TestCase):
|
||||
def get_from_encoderdecoder_pretrained_model(self):
|
||||
return EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
|
||||
|
||||
def get_decoder_config(self):
|
||||
config = AutoConfig.from_pretrained("bert-base-uncased")
|
||||
config.is_decoder = True
|
||||
config.add_cross_attention = True
|
||||
return config
|
||||
|
||||
def get_encoderdecoder_model(self):
|
||||
return EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
|
||||
|
||||
def get_encoder_decoder_models(self):
|
||||
encoder_model = BertModel.from_pretrained("bert-base-uncased")
|
||||
decoder_model = BertLMHeadModel.from_pretrained("bert-base-uncased", config=self.get_decoder_config())
|
||||
return {"encoder": encoder_model, "decoder": decoder_model}
|
||||
|
||||
def _check_configuration_tie(self, model):
|
||||
assert id(model.decoder.config) == id(model.config.decoder)
|
||||
assert id(model.encoder.config) == id(model.config.encoder)
|
||||
|
||||
@slow
|
||||
def test_configuration_tie(self):
|
||||
model = self.get_from_encoderdecoder_pretrained_model()
|
||||
self._check_configuration_tie(model)
|
||||
|
||||
model = EncoderDecoderModel(**self.get_encoder_decoder_models())
|
||||
self._check_configuration_tie(model)
|
||||
|
||||
model = self.get_encoderdecoder_model()
|
||||
self._check_configuration_tie(model)
|
||||
|
||||
Reference in New Issue
Block a user