Fix: save model/model.module

This commit is contained in:
wangfei
2019-08-18 11:03:47 +08:00
parent 1ef41b8337
commit 856a63da4d
2 changed files with 4 additions and 2 deletions

View File

@@ -322,7 +322,8 @@ def main():
# Save a trained model # Save a trained model
if args.local_rank == -1 or torch.distributed.get_rank() == 0: if args.local_rank == -1 or torch.distributed.get_rank() == 0:
logging.info("** ** * Saving fine-tuned model ** ** * ") logging.info("** ** * Saving fine-tuned model ** ** * ")
model.save_pretrained(args.output_dir) model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
model_to_save.save_pretrained(args.output_dir)
tokenizer.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir)

View File

@@ -610,7 +610,8 @@ def main():
# Save a trained model # Save a trained model
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
logger.info("** ** * Saving fine - tuned model ** ** * ") logger.info("** ** * Saving fine - tuned model ** ** * ")
model.save_pretrained(args.output_dir) model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
model_to_save.save_pretrained(args.output_dir)
tokenizer.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir)