diff --git a/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py b/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py index 9f602e1c85..44721d6f41 100644 --- a/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py +++ b/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py @@ -640,7 +640,7 @@ def main(): # Create learning rate schedule linear_decay_lr_schedule_fn = create_learning_rate_fn( - len(vectorized_datasets["train"]), + total_train_steps, training_args.warmup_steps, training_args.learning_rate, )