From f8e98d67793341f8955634c942af1af579f097dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Fri, 11 Oct 2019 16:48:11 +0200 Subject: [PATCH] load pretrained embeddings in Bert decoder In Rothe et al.'s "Leveraging Pre-trained Checkpoints for Sequence Generation Tasks", Bert2Bert is initialized with pre-trained weights for the encoder, and only pre-trained embeddings for the decoder. The current version of the code completely randomizes the weights of the decoder. We write a custom function to initialize the weights of the decoder; we first initialize the decoder with the pretrained weights and then randomize everything but the embeddings. --- transformers/modeling_bert.py | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py index bce7972315..03559ad26c 100644 --- a/transformers/modeling_bert.py +++ b/transformers/modeling_bert.py @@ -1348,15 +1348,14 @@ class Bert2Rnd(BertPreTrainedModel): self.encoder = BertModel(config) self.decoder = BertDecoderModel(config) - self.init_weights() - @classmethod def from_pretrained(cls, pretrained_model_or_path, *model_args, **model_kwargs): """ Load the pretrained weights in the encoder. - Since the decoder needs to be initialized with random weights, and the encoder with - pretrained weights we need to override the `from_pretrained` method of the base `PreTrainedModel` - class. + The encoder of `Bert2Rand` is initialized with pretrained weights; the + weights of the decoder are initialized at random except the embeddings + which are initialized with the pretrained embeddings. We thus need to override + the base class' `from_pretrained` method. 
""" # Load the configuration @@ -1374,10 +1373,26 @@ class Bert2Rnd(BertPreTrainedModel): ) model = cls(config) - # The encoder is loaded with pretrained weights + # We load the encoder with pretrained weights pretrained_encoder = BertModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs) model.encoder = pretrained_encoder + # We load the decoder with pretrained weights and then randomize all weights but embeddings-related one. + def randomize_decoder_weights(module): + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=config.initializer_range) + elif isinstance(module, BertLayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + pretrained_decoder = BertDecoderModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs) + pretrained_decoder.apply(randomize_decoder_weights) + model.decoder = pretrained_decoder + return model def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None): @@ -1386,11 +1401,9 @@ class Bert2Rnd(BertPreTrainedModel): token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask) - encoder_output = encoder_outputs[0] - decoder_outputs = self.decoder(input_ids, - encoder_output, + encoder_outputs[0], token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask) - return decoder_outputs[0] + return decoder_outputs