From 83ef8bcac2f6ce00a3c6256a4ba747c8802480f6 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Thu, 18 Nov 2021 10:25:06 -0500 Subject: [PATCH] Fix finite IterableDataset test on multiple GPUs (#14445) --- tests/test_trainer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index ab280e2bb5..5b2029a299 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -1069,13 +1069,17 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): self.assertIsInstance(loader.sampler, torch.utils.data.dataloader._InfiniteConstantSampler) def test_training_finite_iterable_dataset(self): + num_gpus = max(1, get_gpu_count()) + if num_gpus > 2: + return + config = RegressionModelConfig() model = RegressionPreTrainedModel(config) batch_size = 1 num_samples = 10 - available_steps = num_samples // batch_size + available_steps = num_samples // (batch_size * num_gpus) data = FiniteIterableDataset(length=num_samples) train_args = TrainingArguments(