Fix test for auto_find_batch_size on multi-GPU (#27947)
* Fix test for multi-GPU * WIth CPU handle
This commit is contained in:
@@ -1558,7 +1558,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
class MockCudaOOMCallback(TrainerCallback):
|
||||
def on_step_end(self, args, state, control, **kwargs):
|
||||
# simulate OOM on the first step
|
||||
if state.train_batch_size == 16:
|
||||
if state.train_batch_size >= 16:
|
||||
raise RuntimeError("CUDA out of memory.")
|
||||
|
||||
args = RegressionTrainingArguments(
|
||||
@@ -1577,7 +1577,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
# We can then make a new Trainer
|
||||
trainer = Trainer(model, args, train_dataset=train_dataset)
|
||||
# Check we are at 16 to start
|
||||
self.assertEqual(trainer._train_batch_size, 16)
|
||||
self.assertEqual(trainer._train_batch_size, 16 * max(trainer.args.n_gpu, 1))
|
||||
trainer.train(resume_from_checkpoint=True)
|
||||
# We should be back to 8 again, picking up based upon the last ran Trainer
|
||||
self.assertEqual(trainer._train_batch_size, 8)
|
||||
|
||||
Reference in New Issue
Block a user