diff --git a/examples/run_lm_finetuning.py b/examples/run_lm_finetuning.py index 1d93aa4381..9bdbf9ca56 100644 --- a/examples/run_lm_finetuning.py +++ b/examples/run_lm_finetuning.py @@ -188,6 +188,13 @@ def train(args, train_dataset, model, tokenizer): ] optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) + + # Check if saved optimizer or scheduler states exist + if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')) and os.path.isfile(os.path.join(args.model_name_or_path, 'scheduler.pt')): + # Load in optimizer and scheduler states + optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'optimizer.pt'))) + scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'scheduler.pt'))) + if args.fp16: try: from apex import amp