Fix finite IterableDataset test on multiple GPUs (#14445)

This commit is contained in:
Sylvain Gugger
2021-11-18 10:25:06 -05:00
committed by GitHub
parent da36c557f7
commit 83ef8bcac2

View File

@@ -1069,13 +1069,17 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertIsInstance(loader.sampler, torch.utils.data.dataloader._InfiniteConstantSampler)
def test_training_finite_iterable_dataset(self):
num_gpus = max(1, get_gpu_count())
if num_gpus > 2:
return
config = RegressionModelConfig()
model = RegressionPreTrainedModel(config)
batch_size = 1
num_samples = 10
available_steps = num_samples // batch_size
available_steps = num_samples // (batch_size * num_gpus)
data = FiniteIterableDataset(length=num_samples)
train_args = TrainingArguments(