override from_pretrained in Bert2Rnd
In the seq2seq model we need to both load pretrained weights in the encoder and initialize the decoder randomly. Because the `from_pretrained` method defined in the base class relies on module names to assign weights, it would also initialize the decoder with pretrained weights. To avoid this we override the method to only initialize the encoder with pretrained weights.
This commit is contained in:
@@ -1455,6 +1455,37 @@ class Bert2Rnd(BertPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
@classmethod
def from_pretrained(cls, pretrained_model_or_path, *model_args, **model_kwargs):
    """Build the model and plug a pretrained `BertModel` in as its encoder.

    The base-class `from_pretrained` is overridden so that pretrained
    weights are only loaded for the encoder: the encoder is obtained from
    `BertModel.from_pretrained`, the model itself is built from its
    configuration with `cls(config)`, and the freshly built encoder is then
    replaced by the pretrained one. Every other sub-module keeps whatever
    initialization `cls(config)` gave it.

    Args:
        pretrained_model_or_path: identifier or path forwarded unchanged
            to `BertModel.from_pretrained` and to `cls._load_config`.
        *model_args: extra positional arguments forwarded to both calls.
        **model_kwargs: extra keyword arguments forwarded to both calls.

    Returns:
        An instance of `cls` whose `encoder` attribute is the pretrained
        `BertModel`.
    """
    # Load the encoder first, so a bad checkpoint fails before the model is built.
    encoder = BertModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs)

    # Build the full model from the resolved configuration.
    seq2seq = cls(cls._load_config(pretrained_model_or_path, *model_args, **model_kwargs))

    # Swap the freshly built encoder for the pretrained one.
    seq2seq.encoder = encoder
    return seq2seq
|
||||
|
||||
@classmethod
def _load_config(cls, pretrained_model_name_or_path, *args, **kwargs):
    """Return the configuration the model should be built with.

    This is a classmethod because `from_pretrained` calls it on the class
    (`cls._load_config(...)`) before any instance exists. Declared as a
    plain instance method, that call would bind the path argument to
    `self` and fail.

    Args:
        pretrained_model_name_or_path: identifier or path passed to
            `cls.config_class.from_pretrained` when no configuration is
            supplied by the caller.
        *args: extra positional arguments forwarded to
            `cls.config_class.from_pretrained`.
        **kwargs: may contain `config` (returned as-is when not None),
            `cache_dir` (default None) and `force_download` (default
            False); any remaining entries are forwarded to
            `cls.config_class.from_pretrained`.

    Returns:
        The caller-supplied `config` if one was given, otherwise the first
        element of the pair returned by `cls.config_class.from_pretrained`.
    """
    config = kwargs.pop('config', None)
    if config is not None:
        # An explicit configuration always wins; nothing is loaded.
        return config

    cache_dir = kwargs.pop('cache_dir', None)
    force_download = kwargs.pop('force_download', False)
    # return_unused_kwargs=True makes the loader return a pair; only the
    # configuration itself is needed here.
    config, _ = cls.config_class.from_pretrained(
        pretrained_model_name_or_path,
        *args,
        cache_dir=cache_dir,
        return_unused_kwargs=True,
        force_download=force_download,
        **kwargs
    )
    return config
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
|
||||
encoder_outputs = self.encoder(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
|
||||
Reference in New Issue
Block a user