fixed init_weight

This commit is contained in:
erenup
2019-09-16 19:55:24 +08:00
parent 982f181aa7
commit a9debaca3d
2 changed files with 2 additions and 2 deletions

View File

@@ -418,7 +418,7 @@ class RobertaForMultipleChoice(BertPreTrainedModel):
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, 1) self.classifier = nn.Linear(config.hidden_size, 1)
self.apply(self.init_weights) self.init_weights()
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
position_ids=None, head_mask=None): position_ids=None, head_mask=None):

View File

@@ -1065,7 +1065,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
self.sequence_summary = SequenceSummary(config) self.sequence_summary = SequenceSummary(config)
self.logits_proj = nn.Linear(config.d_model, 1) self.logits_proj = nn.Linear(config.d_model, 1)
self.apply(self.init_weights) self.init_weights()
def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None, def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
mems=None, perm_mask=None, target_mapping=None, mems=None, perm_mask=None, target_mapping=None,