From f717d47fe0c64aa045b62f1298a028d246b16172 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Tue, 28 Jun 2022 08:55:02 +0200 Subject: [PATCH] Fix `test_number_of_steps_in_training_with_ipex` (#17889) Co-authored-by: ydshieh --- tests/trainer/test_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index ce58d6adeb..9b1aa40016 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -649,14 +649,14 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): # Regular training has n_epochs * len(train_dl) steps trainer = get_regression_trainer(learning_rate=0.1, use_ipex=True, bf16=mix_bf16, no_cuda=True) train_output = trainer.train() - self.assertEqual(train_output.global_step, self.n_epochs * 64 / self.batch_size) + self.assertEqual(train_output.global_step, self.n_epochs * 64 / trainer.args.train_batch_size) # Check passing num_train_epochs works (and a float version too): trainer = get_regression_trainer( learning_rate=0.1, num_train_epochs=1.5, use_ipex=True, bf16=mix_bf16, no_cuda=True ) train_output = trainer.train() - self.assertEqual(train_output.global_step, int(1.5 * 64 / self.batch_size)) + self.assertEqual(train_output.global_step, int(1.5 * 64 / trainer.args.train_batch_size)) # If we pass a max_steps, num_train_epochs is ignored trainer = get_regression_trainer(