updating examples

This commit is contained in:
thomwolf
2019-06-19 13:23:20 +02:00
parent 68ab9599ce
commit dc8e0019b7
5 changed files with 212 additions and 42 deletions

View File

@@ -366,7 +366,7 @@ def main():
output_args_file = os.path.join(args.output_dir, 'training_args.bin')
torch.save(args, output_args_file)
else:
model = BertForSequenceClassification.from_pretrained(args.bert_model)
model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
model.to(device)