From 9626e0458c20b61c18c9564ecc4d1261a4a66e50 Mon Sep 17 00:00:00 2001 From: Bilal Khan Date: Wed, 27 Nov 2019 20:00:16 -0600 Subject: [PATCH] Add functionality to continue training from last saved global_step --- examples/run_lm_finetuning.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/examples/run_lm_finetuning.py b/examples/run_lm_finetuning.py index 5e7683b85d..172d4e20e2 100644 --- a/examples/run_lm_finetuning.py +++ b/examples/run_lm_finetuning.py @@ -223,17 +223,37 @@ def train(args, train_dataset, model, tokenizer): logger.info(" Total optimization steps = %d", t_total) global_step = 0 + epochs_trained = 0 + steps_trained_in_current_epoch = 0 + # Check if continuing training from a checkpoint + if os.path.exists(args.model_name_or_path): + # set global_step to gobal_step of last saved checkpoint from model path + global_step = int(args.model_name_or_path.split('-')[-1].split('/')[0]) + epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps) + steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps) + + logger.info(" Continuing training from checkpoint, will skip to saved global_step") + logger.info(" Continuing training from epoch %d", epochs_trained) + logger.info(" Continuing training from global step %d", global_step) + logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) + tr_loss, logging_loss = 0.0, 0.0 model_to_resize = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training model_to_resize.resize_token_embeddings(len(tokenizer)) model.zero_grad() - train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]) + train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]) set_seed(args) # Added here for reproducibility (even between python 2 and 3) for epoch in train_iterator: epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) for step, batch in enumerate(epoch_iterator): + + # Skip past any already trained steps if resuming training + if steps_trained_in_current_epoch > 0: + steps_trained_in_current_epoch -= 1 + continue + inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch) inputs = inputs.to(args.device) labels = labels.to(args.device)