diff --git a/examples/run_ner.py b/examples/run_ner.py index 1ab1236d94..86e74956df 100644 --- a/examples/run_ner.py +++ b/examples/run_ner.py @@ -85,6 +85,13 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id): ] optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) + + # Check if saved optimizer or scheduler states exist + if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')) and os.path.isfile(os.path.join(args.model_name_or_path, 'scheduler.pt')): + # Load in optimizer and scheduler states + optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'optimizer.pt'))) + scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'scheduler.pt'))) + if args.fp16: try: from apex import amp @@ -114,13 +121,33 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id): logger.info(" Total optimization steps = %d", t_total) global_step = 0 + epochs_trained = 0 + steps_trained_in_current_epoch = 0 + # Check if continuing training from a checkpoint + if os.path.exists(args.model_name_or_path): + # set global_step to global_step of last saved checkpoint from model path + global_step = int(args.model_name_or_path.split('-')[-1].split('/')[0]) + epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps) + steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps) + + logger.info(" Continuing training from checkpoint, will skip to saved global_step") + logger.info(" Continuing training from epoch %d", epochs_trained) + logger.info(" Continuing training from global step %d", global_step) + logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) + tr_loss, logging_loss = 0.0, 0.0 model.zero_grad() - 
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]) + train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]) set_seed(args) # Added here for reproductibility (even between python 2 and 3) for _ in train_iterator: epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) for step, batch in enumerate(epoch_iterator): + + # Skip past any already trained steps if resuming training + if steps_trained_in_current_epoch > 0: + steps_trained_in_current_epoch -= 1 + continue + model.train() batch = tuple(t.to(args.device) for t in batch) inputs = {"input_ids": batch[0], @@ -172,9 +199,15 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id): os.makedirs(output_dir) model_to_save = model.module if hasattr(model, "module") else model # Take care of distributed/parallel training model_to_save.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + torch.save(args, os.path.join(output_dir, "training_args.bin")) logger.info("Saving model checkpoint to %s", output_dir) + torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.pt')) + torch.save(scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt')) + logger.info("Saving optimizer and scheduler states to %s", output_dir) + if args.max_steps > 0 and global_step > args.max_steps: epoch_iterator.close() break