From aa90e0c36adc0034ece203c857d0d993c82ae65a Mon Sep 17 00:00:00 2001 From: joe dumoulin Date: Fri, 1 Feb 2019 10:15:44 -0800 Subject: [PATCH] fix prediction on run-squad.py example --- examples/run_squad.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/examples/run_squad.py b/examples/run_squad.py index e2a3f9c924..0cb63edf0e 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -706,7 +706,7 @@ def main(): parser.add_argument("--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.") parser.add_argument("--warmup_proportion", default=0.1, type=float, - help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% " + help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% " "of training.") parser.add_argument("--n_best_size", default=20, type=int, help="The total number of n-best predictions to generate in the nbest_predictions.json " @@ -919,9 +919,12 @@ def main(): if args.do_train: torch.save(model_to_save.state_dict(), output_model_file) - # Load a trained model that you have fine-tuned - model_state_dict = torch.load(output_model_file) - model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict) + # Load a trained model that you have fine-tuned + model_state_dict = torch.load(output_model_file) + model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict) + else: + model = BertForQuestionAnswering.from_pretrained(args.bert_model) + model.to(device) if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0):