Applied patch to OpenAI GPT, RoBERTa, TransfoL, XLM and XLNet

This commit is contained in:
LysandreJik
2019-08-31 00:33:11 -04:00
parent bdb4409ed8
commit b6992b7b47
5 changed files with 27 additions and 41 deletions

View File

@@ -168,7 +168,7 @@ class RobertaModel(BertModel):
super(RobertaModel, self).__init__(config)
self.embeddings = RobertaEmbeddings(config)
self.apply(self.init_weights)
self.init_weights()
def forward(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None, head_mask=None):
if input_ids[:, 0].sum().item() != 0:
@@ -220,7 +220,7 @@ class RobertaForMaskedLM(BertPreTrainedModel):
self.roberta = RobertaModel(config)
self.lm_head = RobertaLMHead(config)
self.apply(self.init_weights)
self.init_weights()
self.tie_weights()
def tie_weights(self):