fixed init_weight
This commit is contained in:
@@ -418,7 +418,7 @@ class RobertaForMultipleChoice(BertPreTrainedModel):
|
|||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
self.classifier = nn.Linear(config.hidden_size, 1)
|
self.classifier = nn.Linear(config.hidden_size, 1)
|
||||||
|
|
||||||
self.apply(self.init_weights)
|
self.init_weights()
|
||||||
|
|
||||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
|
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
|
||||||
position_ids=None, head_mask=None):
|
position_ids=None, head_mask=None):
|
||||||
|
|||||||
@@ -1065,7 +1065,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
|
|||||||
self.sequence_summary = SequenceSummary(config)
|
self.sequence_summary = SequenceSummary(config)
|
||||||
self.logits_proj = nn.Linear(config.d_model, 1)
|
self.logits_proj = nn.Linear(config.d_model, 1)
|
||||||
|
|
||||||
self.apply(self.init_weights)
|
self.init_weights()
|
||||||
|
|
||||||
def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
|
def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
|
||||||
mems=None, perm_mask=None, target_mapping=None,
|
mems=None, perm_mask=None, target_mapping=None,
|
||||||
|
|||||||
Reference in New Issue
Block a user