Fix Trainer with a parallel model (#9578)
* Fix Trainer with a parallel model * More clean up
This commit is contained in:
@@ -381,9 +381,11 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
# Make the Trainer believe it's a parallelized model
|
||||
model.is_parallelizable = True
|
||||
model.model_parallel = True
|
||||
trainer = Trainer(model=model, train_dataset=RegressionDataset(), eval_dataset=RegressionDataset())
|
||||
args = TrainingArguments("./regression", per_device_train_batch_size=16, per_device_eval_batch_size=16)
|
||||
trainer = Trainer(model, args, train_dataset=RegressionDataset(), eval_dataset=RegressionDataset())
|
||||
# Check the Trainer was fooled
|
||||
self.assertTrue(trainer.is_model_parallel)
|
||||
self.assertEqual(trainer.args.n_gpu, 1)
|
||||
|
||||
# The batch size of the training and evaluation dataloaders should be 16, not 16 * n_gpu
|
||||
self.assertEqual(trainer.get_train_dataloader().batch_size, 16)
|
||||
|
||||
Reference in New Issue
Block a user