From 856a63da4d1f0f302633dc73e2d4a1f698bbafda Mon Sep 17 00:00:00 2001 From: wangfei <1140554608@qq.com> Date: Sun, 18 Aug 2019 11:03:47 +0800 Subject: [PATCH] Fix: save model/model.module --- examples/lm_finetuning/finetune_on_pregenerated.py | 3 ++- examples/lm_finetuning/simple_lm_finetuning.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/lm_finetuning/finetune_on_pregenerated.py b/examples/lm_finetuning/finetune_on_pregenerated.py index 7c40342f18..eefa56c824 100644 --- a/examples/lm_finetuning/finetune_on_pregenerated.py +++ b/examples/lm_finetuning/finetune_on_pregenerated.py @@ -322,7 +322,8 @@ def main(): # Save a trained model if args.local_rank == -1 or torch.distributed.get_rank() == 0: logging.info("** ** * Saving fine-tuned model ** ** * ") - model.save_pretrained(args.output_dir) + model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training + model_to_save.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir) diff --git a/examples/lm_finetuning/simple_lm_finetuning.py b/examples/lm_finetuning/simple_lm_finetuning.py index 25333de0ed..9633640faf 100644 --- a/examples/lm_finetuning/simple_lm_finetuning.py +++ b/examples/lm_finetuning/simple_lm_finetuning.py @@ -610,7 +610,8 @@ def main(): # Save a trained model if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): logger.info("** ** * Saving fine - tuned model ** ** * ") - model.save_pretrained(args.output_dir) + model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training + model_to_save.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir)