From 877ef2c6cae3059ff9307387baaed886139c5eff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Thu, 10 Oct 2019 10:02:18 +0200 Subject: [PATCH] override `from_pretrained` in Bert2Rnd In the seq2seq model we need to both load pretrained weights in the encoder and initialize the decoder randomly. Because the `from_pretrained` method defined in the base class relies on module names to assign weights, it would also initialize the decoder with pretrained weights. To avoid this we override the method to only initialize the encoder with pretrained weights. --- transformers/modeling_bert.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py index fc698c772e..db8847f39e 100644 --- a/transformers/modeling_bert.py +++ b/transformers/modeling_bert.py @@ -1455,6 +1455,37 @@ class Bert2Rnd(BertPreTrainedModel): self.init_weights() + @classmethod + def from_pretrained(cls, pretrained_model_or_path, *model_args, **model_kwargs): + """ Load the pretrained weights in the encoder. + + Since the decoder needs to be initialized with random weights, and the encoder with + pretrained weights we need to override the `from_pretrained` method of the base `PreTrainedModel` + class. + """ + pretrained_encoder = BertModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs) + + config = cls._load_config(pretrained_model_or_path, *model_args, **model_kwargs) + model = cls(config) + model.encoder = pretrained_encoder + + return model + + def _load_config(self, pretrained_model_name_or_path, *args, **kwargs): + config = kwargs.pop('config', None) + if config is None: + cache_dir = kwargs.pop('cache_dir', None) + force_download = kwargs.pop('force_download', False) + config, _ = self.config_class.from_pretrained( + pretrained_model_name_or_path, + *args, + cache_dir=cache_dir, + return_unused_kwargs=True, + force_download=force_download, + **kwargs + ) + return config + def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None): encoder_outputs = self.encoder(input_ids, attention_mask=attention_mask,