diff --git a/examples/run_lm_finetuning.py b/examples/run_lm_finetuning.py index a5eaf524ac..3cae206460 100644 --- a/examples/run_lm_finetuning.py +++ b/examples/run_lm_finetuning.py @@ -224,7 +224,7 @@ def train(args, train_dataset, model, tokenizer): model.zero_grad() train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]) set_seed(args) # Added here for reproducibility (even between python 2 and 3) - for _ in train_iterator: + for epoch in train_iterator: epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) for step, batch in enumerate(epoch_iterator): inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch) @@ -279,6 +279,10 @@ def train(args, train_dataset, model, tokenizer): _rotate_checkpoints(args, checkpoint_prefix) + torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.pt')) + torch.save(scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt')) + torch.save(epoch, os.path.join(output_dir, 'training_state.pt')) + if args.max_steps > 0 and global_step > args.max_steps: epoch_iterator.close() break