From 3f6973db06d0149ee94a71a8f7cf4c374c675cd4 Mon Sep 17 00:00:00 2001 From: Fanli Lin Date: Fri, 8 Mar 2024 23:52:25 +0800 Subject: [PATCH] [tests] use the correct `n_gpu` in `TrainerIntegrationTest::test_train_and_eval_dataloaders` for XPU (#29307) * fix n_gpu * fix style --- tests/trainer/test_trainer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 98f3c96b4e..bd704bc8b5 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1029,7 +1029,10 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): self.assertFalse(is_any_loss_nan_or_inf(log_history_filter)) def test_train_and_eval_dataloaders(self): - n_gpu = max(1, backend_device_count(torch_device)) + if torch_device == "cuda": + n_gpu = max(1, backend_device_count(torch_device)) + else: + n_gpu = 1 trainer = get_regression_trainer(learning_rate=0.1, per_device_train_batch_size=16) self.assertEqual(trainer.get_train_dataloader().total_batch_size, 16 * n_gpu) trainer = get_regression_trainer(learning_rate=0.1, per_device_eval_batch_size=16)