Merge pull request #248 from JoeDumoulin/squad1.1-fix
fix prediction on run-squad.py example
This commit is contained in:
@@ -701,7 +701,7 @@ def main():
|
|||||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
||||||
help="Total number of training epochs to perform.")
|
help="Total number of training epochs to perform.")
|
||||||
parser.add_argument("--warmup_proportion", default=0.1, type=float,
|
parser.add_argument("--warmup_proportion", default=0.1, type=float,
|
||||||
help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
|
help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% "
|
||||||
"of training.")
|
"of training.")
|
||||||
parser.add_argument("--n_best_size", default=20, type=int,
|
parser.add_argument("--n_best_size", default=20, type=int,
|
||||||
help="The total number of n-best predictions to generate in the nbest_predictions.json "
|
help="The total number of n-best predictions to generate in the nbest_predictions.json "
|
||||||
@@ -915,9 +915,12 @@ def main():
|
|||||||
if args.do_train:
|
if args.do_train:
|
||||||
torch.save(model_to_save.state_dict(), output_model_file)
|
torch.save(model_to_save.state_dict(), output_model_file)
|
||||||
|
|
||||||
# Load a trained model that you have fine-tuned
|
# Load a trained model that you have fine-tuned
|
||||||
model_state_dict = torch.load(output_model_file)
|
model_state_dict = torch.load(output_model_file)
|
||||||
model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict)
|
model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict)
|
||||||
|
else:
|
||||||
|
model = BertForQuestionAnswering.from_pretrained(args.bert_model)
|
||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||||
|
|||||||
Reference in New Issue
Block a user