Use saved optimizer and scheduler states if available
This commit is contained in:
committed by
Lysandre Debut
parent
a03fcf570d
commit
0eb973b0d9
@@ -188,6 +188,13 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
]
|
]
|
||||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||||
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
|
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
|
||||||
|
|
||||||
|
# Check if saved optimizer or scheduler states exist
|
||||||
|
if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')) and os.path.isfile(os.path.join(args.model_name_or_path, 'scheduler.pt')):
|
||||||
|
# Load in optimizer and scheduler states
|
||||||
|
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'optimizer.pt')))
|
||||||
|
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'scheduler.pt')))
|
||||||
|
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
try:
|
try:
|
||||||
from apex import amp
|
from apex import amp
|
||||||
|
|||||||
Reference in New Issue
Block a user