special edition script
This commit is contained in:
@@ -480,9 +480,9 @@ def main():
|
||||
|
||||
model = BertForSequenceClassification(bert_config, len(label_list))
|
||||
if args.init_checkpoint is not None:
|
||||
model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
|
||||
model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
|
||||
model.to(device)
|
||||
|
||||
|
||||
if n_gpu > 1:
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
@@ -575,7 +575,7 @@ def main():
|
||||
|
||||
eval_loss += tmp_eval_loss.item()
|
||||
eval_accuracy += tmp_eval_accuracy
|
||||
|
||||
|
||||
nb_eval_examples += input_ids.size(0)
|
||||
|
||||
eval_loss = eval_loss / nb_eval_examples #len(eval_dataloader)
|
||||
|
||||
Reference in New Issue
Block a user