From a9debaca3dfd3d12c3acf896e3e2272d0a087257 Mon Sep 17 00:00:00 2001 From: erenup Date: Mon, 16 Sep 2019 19:55:24 +0800 Subject: [PATCH] fixed init_weight --- pytorch_transformers/modeling_roberta.py | 2 +- pytorch_transformers/modeling_xlnet.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_transformers/modeling_roberta.py b/pytorch_transformers/modeling_roberta.py index 98bfa4202a..2b64893c2e 100644 --- a/pytorch_transformers/modeling_roberta.py +++ b/pytorch_transformers/modeling_roberta.py @@ -418,7 +418,7 @@ class RobertaForMultipleChoice(BertPreTrainedModel): self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.Linear(config.hidden_size, 1) - self.apply(self.init_weights) + self.init_weights() def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, position_ids=None, head_mask=None): diff --git a/pytorch_transformers/modeling_xlnet.py b/pytorch_transformers/modeling_xlnet.py index 91390b9d6b..fa65d83b0e 100644 --- a/pytorch_transformers/modeling_xlnet.py +++ b/pytorch_transformers/modeling_xlnet.py @@ -1065,7 +1065,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel): self.sequence_summary = SequenceSummary(config) self.logits_proj = nn.Linear(config.d_model, 1) - self.apply(self.init_weights) + self.init_weights() def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,