Fix test for auto_find_batch_size on multi-GPU (#27947)
* Fix test for multi-GPU * With CPU handle
This commit is contained in:
@@ -1558,7 +1558,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
class MockCudaOOMCallback(TrainerCallback):
|
class MockCudaOOMCallback(TrainerCallback):
|
||||||
def on_step_end(self, args, state, control, **kwargs):
|
def on_step_end(self, args, state, control, **kwargs):
|
||||||
# simulate OOM on the first step
|
# simulate OOM on the first step
|
||||||
if state.train_batch_size == 16:
|
if state.train_batch_size >= 16:
|
||||||
raise RuntimeError("CUDA out of memory.")
|
raise RuntimeError("CUDA out of memory.")
|
||||||
|
|
||||||
args = RegressionTrainingArguments(
|
args = RegressionTrainingArguments(
|
||||||
@@ -1577,7 +1577,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
# We can then make a new Trainer
|
# We can then make a new Trainer
|
||||||
trainer = Trainer(model, args, train_dataset=train_dataset)
|
trainer = Trainer(model, args, train_dataset=train_dataset)
|
||||||
# Check we are at 16 to start
|
# Check we are at 16 to start
|
||||||
self.assertEqual(trainer._train_batch_size, 16)
|
self.assertEqual(trainer._train_batch_size, 16 * max(trainer.args.n_gpu, 1))
|
||||||
trainer.train(resume_from_checkpoint=True)
|
trainer.train(resume_from_checkpoint=True)
|
||||||
# We should be back to 8 again, picking up based upon the last ran Trainer
|
# We should be back to 8 again, picking up based upon the last ran Trainer
|
||||||
self.assertEqual(trainer._train_batch_size, 8)
|
self.assertEqual(trainer._train_batch_size, 8)
|
||||||
|
|||||||
Reference in New Issue
Block a user