Fix lr_scheduler in no_trainer training scripts (#27872)

* Fix lr_scheduler

* Fix lr scheduler
This commit is contained in:
bofeng huang
2024-01-22 15:22:18 +01:00
committed by GitHub
parent 692c3c6b73
commit deb2b59073
9 changed files with 18 additions and 18 deletions

View File

@@ -438,8 +438,8 @@ def main():
lr_scheduler = get_scheduler( lr_scheduler = get_scheduler(
name=args.lr_scheduler_type, name=args.lr_scheduler_type,
optimizer=optimizer, optimizer=optimizer,
num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps, num_warmup_steps=args.num_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, num_training_steps=args.max_train_steps if overrode_max_train_steps else args.max_train_steps * accelerator.num_processes,
) )
# Prepare everything with our `accelerator`. # Prepare everything with our `accelerator`.

View File

@@ -626,8 +626,8 @@ def main():
lr_scheduler = get_scheduler( lr_scheduler = get_scheduler(
name=args.lr_scheduler_type, name=args.lr_scheduler_type,
optimizer=optimizer, optimizer=optimizer,
num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps, num_warmup_steps=args.num_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, num_training_steps=args.max_train_steps if overrode_max_train_steps else args.max_train_steps * accelerator.num_processes,
) )
# Prepare everything with our `accelerator`. # Prepare everything with our `accelerator`.

View File

@@ -526,8 +526,8 @@ def main():
lr_scheduler = get_scheduler( lr_scheduler = get_scheduler(
name=args.lr_scheduler_type, name=args.lr_scheduler_type,
optimizer=optimizer, optimizer=optimizer,
num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps, num_warmup_steps=args.num_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, num_training_steps=args.max_train_steps if overrode_max_train_steps else args.max_train_steps * accelerator.num_processes,
) )
# Prepare everything with our `accelerator`. # Prepare everything with our `accelerator`.

View File

@@ -563,8 +563,8 @@ def main():
lr_scheduler = get_scheduler( lr_scheduler = get_scheduler(
name=args.lr_scheduler_type, name=args.lr_scheduler_type,
optimizer=optimizer, optimizer=optimizer,
num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps, num_warmup_steps=args.num_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, num_training_steps=args.max_train_steps if overrode_max_train_steps else args.max_train_steps * accelerator.num_processes,
) )
# Prepare everything with our `accelerator`. # Prepare everything with our `accelerator`.

View File

@@ -510,8 +510,8 @@ def main():
lr_scheduler = get_scheduler( lr_scheduler = get_scheduler(
name=args.lr_scheduler_type, name=args.lr_scheduler_type,
optimizer=optimizer, optimizer=optimizer,
num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps, num_warmup_steps=args.num_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, num_training_steps=args.max_train_steps if overrode_max_train_steps else args.max_train_steps * accelerator.num_processes,
) )
# Prepare everything with our `accelerator`. # Prepare everything with our `accelerator`.

View File

@@ -750,8 +750,8 @@ def main():
lr_scheduler = get_scheduler( lr_scheduler = get_scheduler(
name=args.lr_scheduler_type, name=args.lr_scheduler_type,
optimizer=optimizer, optimizer=optimizer,
num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps, num_warmup_steps=args.num_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, num_training_steps=args.max_train_steps if overrode_max_train_steps else args.max_train_steps * accelerator.num_processes,
) )
# Prepare everything with our `accelerator`. # Prepare everything with our `accelerator`.

View File

@@ -780,8 +780,8 @@ def main():
lr_scheduler = get_scheduler( lr_scheduler = get_scheduler(
name=args.lr_scheduler_type, name=args.lr_scheduler_type,
optimizer=optimizer, optimizer=optimizer,
num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps, num_warmup_steps=args.num_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, num_training_steps=args.max_train_steps if overrode_max_train_steps else args.max_train_steps * accelerator.num_processes,
) )
# Prepare everything with our `accelerator`. # Prepare everything with our `accelerator`.

View File

@@ -513,8 +513,8 @@ def main():
lr_scheduler = get_scheduler( lr_scheduler = get_scheduler(
name=args.lr_scheduler_type, name=args.lr_scheduler_type,
optimizer=optimizer, optimizer=optimizer,
num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps, num_warmup_steps=args.num_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, num_training_steps=args.max_train_steps if overrode_max_train_steps else args.max_train_steps * accelerator.num_processes,
) )
# Prepare everything with our `accelerator`. # Prepare everything with our `accelerator`.

View File

@@ -580,8 +580,8 @@ def main():
lr_scheduler = get_scheduler( lr_scheduler = get_scheduler(
name=args.lr_scheduler_type, name=args.lr_scheduler_type,
optimizer=optimizer, optimizer=optimizer,
num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps, num_warmup_steps=args.num_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, num_training_steps=args.max_train_steps if overrode_max_train_steps else args.max_train_steps * accelerator.num_processes,
) )
# Prepare everything with our `accelerator`. # Prepare everything with our `accelerator`.