Fix error with global step in run_lm_finetuning.py
This commit is contained in:
@@ -264,15 +264,19 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
steps_trained_in_current_epoch = 0
|
steps_trained_in_current_epoch = 0
|
||||||
# Check if continuing training from a checkpoint
|
# Check if continuing training from a checkpoint
|
||||||
if os.path.exists(args.model_name_or_path):
|
if os.path.exists(args.model_name_or_path):
|
||||||
# set global_step to gobal_step of last saved checkpoint from model path
|
try:
|
||||||
global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
|
# set global_step to gobal_step of last saved checkpoint from model path
|
||||||
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
|
checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
|
||||||
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
|
global_step = int(checkpoint_suffix)
|
||||||
|
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||||
|
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||||
|
|
||||||
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
|
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
|
||||||
logger.info(" Continuing training from epoch %d", epochs_trained)
|
logger.info(" Continuing training from epoch %d", epochs_trained)
|
||||||
logger.info(" Continuing training from global step %d", global_step)
|
logger.info(" Continuing training from global step %d", global_step)
|
||||||
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
|
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
|
||||||
|
except ValueError:
|
||||||
|
logger.info(" Starting fine-tuning.")
|
||||||
|
|
||||||
tr_loss, logging_loss = 0.0, 0.0
|
tr_loss, logging_loss = 0.0, 0.0
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user