@@ -322,8 +322,7 @@ def main():
|
||||
# Save a trained model
|
||||
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
|
||||
logging.info("** ** * Saving fine-tuned model ** ** * ")
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(args.output_dir)
|
||||
model.save_pretrained(args.output_dir)
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
|
||||
|
||||
|
||||
@@ -610,8 +610,7 @@ def main():
|
||||
# Save a trained model
|
||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
logger.info("** ** * Saving fine - tuned model ** ** * ")
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(args.output_dir)
|
||||
model.save_pretrained(args.output_dir)
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user