override from_pretrained in Bert2Rnd

In the seq2seq model we need to both load pretrained weights in the
encoder and initialize the decoder randomly. Because the
`from_pretrained` method defined in the base class relies on module
names to assign weights, it would also initialize the decoder with
pretrained weights. To avoid this we override the method to only
initialize the encoder with pretrained weights.
This commit is contained in:
Rémi Louf
2019-10-10 10:02:18 +02:00
parent 851ef592c5
commit 877ef2c6ca

View File

@@ -1455,6 +1455,37 @@ class Bert2Rnd(BertPreTrainedModel):
self.init_weights()
@classmethod
def from_pretrained(cls, pretrained_model_or_path, *model_args, **model_kwargs):
""" Load the pretrained weights in the encoder.
Since the decoder needs to be initialized with random weights, and the encoder with
pretrained weights we need to override the `from_pretrained` method of the base `PreTrainedModel`
class.
"""
pretrained_encoder = BertModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs)
config = cls._load_config(pretrained_model_or_path, *model_args, **model_kwargs)
model = cls(config)
model.encoder = pretrained_encoder
return model
def _load_config(self, pretrained_model_name_or_path, *args, **kwargs):
config = kwargs.pop('config', None)
if config is None:
cache_dir = kwargs.pop('cache_dir', None)
force_download = kwargs.pop('force_download', False)
config, _ = self.config_class.from_pretrained(
pretrained_model_name_or_path,
*args,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
**kwargs
)
return config
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
encoder_outputs = self.encoder(input_ids,
attention_mask=attention_mask,