make style
This commit is contained in:
@@ -228,8 +228,10 @@ def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedToke
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Check if saved optimizer or scheduler states exist
|
# Check if saved optimizer or scheduler states exist
|
||||||
if args.model_name_or_path and os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
|
if (
|
||||||
os.path.join(args.model_name_or_path, "scheduler.pt")
|
args.model_name_or_path
|
||||||
|
and os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt"))
|
||||||
|
and os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt"))
|
||||||
):
|
):
|
||||||
# Load in optimizer and scheduler states
|
# Load in optimizer and scheduler states
|
||||||
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
|
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
|
||||||
@@ -587,9 +589,7 @@ def main():
|
|||||||
if args.should_continue:
|
if args.should_continue:
|
||||||
sorted_checkpoints = _sorted_checkpoints(args)
|
sorted_checkpoints = _sorted_checkpoints(args)
|
||||||
if len(sorted_checkpoints) == 0:
|
if len(sorted_checkpoints) == 0:
|
||||||
raise ValueError(
|
raise ValueError("Used --should_continue but no checkpoint was found in --output_dir.")
|
||||||
"Used --should_continue but no checkpoint was found in --output_dir."
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
args.model_name_or_path = sorted_checkpoints[-1]
|
args.model_name_or_path = sorted_checkpoints[-1]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user