From 506e5bb0c849b210b4284d975338963029f2c0d5 Mon Sep 17 00:00:00 2001 From: tholor Date: Fri, 11 Jan 2019 08:31:37 +0100 Subject: [PATCH] add do_lower_case arg and adjust model saving for lm finetuning. --- examples/run_lm_finetuning.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/examples/run_lm_finetuning.py b/examples/run_lm_finetuning.py index 39df2e99f8..35d1808bbc 100644 --- a/examples/run_lm_finetuning.py +++ b/examples/run_lm_finetuning.py @@ -461,6 +461,9 @@ def main(): parser.add_argument("--on_memory", action='store_true', help="Whether to load train samples into memory or use disk") + parser.add_argument("--do_lower_case", + action='store_true', + help="Whether to lower case the input text. True for uncased models, False for cased models.") parser.add_argument("--local_rank", type=int, default=-1, @@ -612,12 +615,12 @@ def main(): optimizer.zero_grad() global_step += 1 + # Save a trained model logger.info("** ** * Saving fine - tuned model ** ** * ") + model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self output_model_file = os.path.join(args.output_dir, "pytorch_model.bin") - if n_gpu > 1: - torch.save(model.module.bert.state_dict(), output_model_file) - else: - torch.save(model.bert.state_dict(), output_model_file) + if args.do_train: + torch.save(model_to_save.state_dict(), output_model_file) def _truncate_seq_pair(tokens_a, tokens_b, max_length):