adding save checkpoint and loading in examples

This commit is contained in:
thomwolf
2018-12-09 17:04:23 -05:00
parent 85fff78c2d
commit b3caec5a56
2 changed files with 6 additions and 2 deletions

View File

@@ -329,7 +329,7 @@ def main():
default=None,
type=str,
required=True,
help="The output directory where the model checkpoints will be written.")
help="The output directory where the model predictions and checkpoints will be written.")
## Other parameters
parser.add_argument("--max_seq_length",
@@ -593,6 +593,10 @@ def main():
'global_step': global_step,
'loss': tr_loss/nb_tr_steps}
model_to_save = model.module if hasattr(model, 'module') else model
raise NotImplementedError # TODO add save of the configuration file and vocabulary file also ?
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
torch.save(model_to_save, output_model_file)
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results *****")