[run_ner] Don't crash if fine-tuning local model that doesn't end with digit
This commit is contained in:
@@ -160,7 +160,10 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
|
|||||||
# Check if continuing training from a checkpoint
|
# Check if continuing training from a checkpoint
|
||||||
if os.path.exists(args.model_name_or_path):
|
if os.path.exists(args.model_name_or_path):
|
||||||
# set global_step to gobal_step of last saved checkpoint from model path
|
# set global_step to gobal_step of last saved checkpoint from model path
|
||||||
global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
|
try:
|
||||||
|
global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
|
||||||
|
except ValueError:
|
||||||
|
global_step = 0
|
||||||
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
|
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||||
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
|
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user