From 27c1b656cca75efa0cc414d3bf4e6aacf24829de Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Tue, 7 Jan 2020 16:16:12 +0100 Subject: [PATCH] Fix error with global step in run_lm_finetuning.py --- examples/run_lm_finetuning.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/examples/run_lm_finetuning.py b/examples/run_lm_finetuning.py index c2a0f8c08e..034ea59330 100644 --- a/examples/run_lm_finetuning.py +++ b/examples/run_lm_finetuning.py @@ -264,15 +264,19 @@ def train(args, train_dataset, model, tokenizer): steps_trained_in_current_epoch = 0 # Check if continuing training from a checkpoint if os.path.exists(args.model_name_or_path): - # set global_step to gobal_step of last saved checkpoint from model path - global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0]) - epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps) - steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps) + try: + # set global_step to gobal_step of last saved checkpoint from model path + checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0] + global_step = int(checkpoint_suffix) + epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps) + steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps) - logger.info(" Continuing training from checkpoint, will skip to saved global_step") - logger.info(" Continuing training from epoch %d", epochs_trained) - logger.info(" Continuing training from global step %d", global_step) - logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) + logger.info(" Continuing training from checkpoint, will skip to saved global_step") + logger.info(" Continuing training from epoch %d", epochs_trained) + logger.info(" Continuing training from global step %d", global_step) + logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) + except ValueError: + logger.info(" Starting fine-tuning.") tr_loss, logging_loss = 0.0, 0.0