diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 98f3c96b4e..bd704bc8b5 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1029,7 +1029,10 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): self.assertFalse(is_any_loss_nan_or_inf(log_history_filter)) def test_train_and_eval_dataloaders(self): - n_gpu = max(1, backend_device_count(torch_device)) + if torch_device == "cuda": + n_gpu = max(1, backend_device_count(torch_device)) + else: + n_gpu = 1 trainer = get_regression_trainer(learning_rate=0.1, per_device_train_batch_size=16) self.assertEqual(trainer.get_train_dataloader().total_batch_size, 16 * n_gpu) trainer = get_regression_trainer(learning_rate=0.1, per_device_eval_batch_size=16)