fix: fix gradient accumulate step for learning rate (#27667)
This commit is contained in:
@@ -640,7 +640,7 @@ def main():
|
|||||||
|
|
||||||
# Create learning rate schedule
|
# Create learning rate schedule
|
||||||
linear_decay_lr_schedule_fn = create_learning_rate_fn(
|
linear_decay_lr_schedule_fn = create_learning_rate_fn(
|
||||||
len(vectorized_datasets["train"]),
|
total_train_steps,
|
||||||
training_args.warmup_steps,
|
training_args.warmup_steps,
|
||||||
training_args.learning_rate,
|
training_args.learning_rate,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user