From b3caec5a5662f591a1f148bf34ba4f853be514e2 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Sun, 9 Dec 2018 17:04:23 -0500 Subject: [PATCH] adding save checkpoint and loading in examples --- examples/run_classifier.py | 6 +++++- examples/run_squad.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/run_classifier.py b/examples/run_classifier.py index a531ea5725..f18c5489ba 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -329,7 +329,7 @@ def main(): default=None, type=str, required=True, - help="The output directory where the model checkpoints will be written.") + help="The output directory where the model predictions and checkpoints will be written.") ## Other parameters parser.add_argument("--max_seq_length", @@ -593,6 +593,10 @@ def main(): 'global_step': global_step, 'loss': tr_loss/nb_tr_steps} + model_to_save = model.module if hasattr(model, 'module') else model + raise NotImplementedError # TODO add save of the configuration file and vocabulary file also ? + output_model_file = os.path.join(args.output_dir, "pytorch_model.bin") + torch.save(model_to_save, output_model_file) output_eval_file = os.path.join(args.output_dir, "eval_results.txt") with open(output_eval_file, "w") as writer: logger.info("***** Eval results *****") diff --git a/examples/run_squad.py b/examples/run_squad.py index b96fcece37..b0668b38d8 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -690,7 +690,7 @@ def main(): help="Bert pre-trained model selected in the list: bert-base-uncased, " "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.") parser.add_argument("--output_dir", default=None, type=str, required=True, - help="The output directory where the model checkpoints will be written.") + help="The output directory where the model checkpoints and predictions will be written.") ## Other parameters parser.add_argument("--train_file", default=None, type=str, help="SQuAD json for training. E.g., train-v1.1.json")