Fix test for auto_find_batch_size on multi-GPU (#27947)

* Fix test for multi-GPU

* WIth CPU handle
This commit is contained in:
Zach Mueller
2023-12-11 09:57:41 -05:00
committed by GitHub
parent b911c1f10f
commit 44127ec667

View File

@@ -1558,7 +1558,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
class MockCudaOOMCallback(TrainerCallback):
def on_step_end(self, args, state, control, **kwargs):
# simulate OOM on the first step
if state.train_batch_size == 16:
if state.train_batch_size >= 16:
raise RuntimeError("CUDA out of memory.")
args = RegressionTrainingArguments(
@@ -1577,7 +1577,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
# We can then make a new Trainer
trainer = Trainer(model, args, train_dataset=train_dataset)
# Check we are at 16 to start
self.assertEqual(trainer._train_batch_size, 16)
self.assertEqual(trainer._train_batch_size, 16 * max(trainer.args.n_gpu, 1))
trainer.train(resume_from_checkpoint=True)
# We should be back to 8 again, picking up based upon the last ran Trainer
self.assertEqual(trainer._train_batch_size, 8)