special edition script

This commit is contained in:
thomwolf
2018-11-03 19:06:15 +01:00
parent 25f73add07
commit 04287a4d68
3 changed files with 108 additions and 4 deletions

View File

@@ -480,9 +480,9 @@ def main():
model = BertForSequenceClassification(bert_config, len(label_list))
if args.init_checkpoint is not None:
model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
model.to(device)
if n_gpu > 1:
model = torch.nn.DataParallel(model)
@@ -575,7 +575,7 @@ def main():
eval_loss += tmp_eval_loss.item()
eval_accuracy += tmp_eval_accuracy
nb_eval_examples += input_ids.size(0)
eval_loss = eval_loss / nb_eval_examples #len(eval_dataloader)