Fix Trainer with a parallel model (#9578)

* Fix Trainer with a parallel model

* More clean up
This commit is contained in:
Sylvain Gugger
2021-01-14 03:23:41 -05:00
committed by GitHub
parent 126fd281bc
commit 5e1bea4f16
2 changed files with 14 additions and 13 deletions

View File

@@ -381,9 +381,11 @@ class TrainerIntegrationTest(unittest.TestCase):
# Make the Trainer believe it's a parallelized model
model.is_parallelizable = True
model.model_parallel = True
trainer = Trainer(model=model, train_dataset=RegressionDataset(), eval_dataset=RegressionDataset())
args = TrainingArguments("./regression", per_device_train_batch_size=16, per_device_eval_batch_size=16)
trainer = Trainer(model, args, train_dataset=RegressionDataset(), eval_dataset=RegressionDataset())
# Check the Trainer was fooled
self.assertTrue(trainer.is_model_parallel)
self.assertEqual(trainer.args.n_gpu, 1)
# The batch size of the training and evaluation dataloaders should be 16, not 16 * n_gpu
self.assertEqual(trainer.get_train_dataloader().batch_size, 16)