Fix finite IterableDataset test on multiple GPUs (#14445)
This commit is contained in:
@@ -1069,13 +1069,17 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertIsInstance(loader.sampler, torch.utils.data.dataloader._InfiniteConstantSampler)
|
self.assertIsInstance(loader.sampler, torch.utils.data.dataloader._InfiniteConstantSampler)
|
||||||
|
|
||||||
def test_training_finite_iterable_dataset(self):
|
def test_training_finite_iterable_dataset(self):
|
||||||
|
num_gpus = max(1, get_gpu_count())
|
||||||
|
if num_gpus > 2:
|
||||||
|
return
|
||||||
|
|
||||||
config = RegressionModelConfig()
|
config = RegressionModelConfig()
|
||||||
model = RegressionPreTrainedModel(config)
|
model = RegressionPreTrainedModel(config)
|
||||||
|
|
||||||
batch_size = 1
|
batch_size = 1
|
||||||
num_samples = 10
|
num_samples = 10
|
||||||
|
|
||||||
available_steps = num_samples // batch_size
|
available_steps = num_samples // (batch_size * num_gpus)
|
||||||
|
|
||||||
data = FiniteIterableDataset(length=num_samples)
|
data = FiniteIterableDataset(length=num_samples)
|
||||||
train_args = TrainingArguments(
|
train_args = TrainingArguments(
|
||||||
|
|||||||
Reference in New Issue
Block a user