From 0410a29a2d5c798b2c0c1ca28398e0ddcf3384f2 Mon Sep 17 00:00:00 2001 From: Phuc Van Phan Date: Thu, 7 Dec 2023 13:59:26 +0700 Subject: [PATCH] fix: fix gradient accumulate step for learning rate (#27667) --- .../speech-recognition/run_flax_speech_recognition_seq2seq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py b/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py index 9f602e1c85..44721d6f41 100644 --- a/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py +++ b/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py @@ -640,7 +640,7 @@ def main(): # Create learning rate schedule linear_decay_lr_schedule_fn = create_learning_rate_fn( - len(vectorized_datasets["train"]), + total_train_steps, training_args.warmup_steps, training_args.learning_rate, )