Support DeepSpeed when using auto find batch size (#28088)

Fixup test
This commit is contained in:
Zach Mueller
2024-01-10 06:03:13 -05:00
committed by GitHub
parent a777f52599
commit 6015d0ad6c
3 changed files with 84 additions and 17 deletions

View File

@@ -57,6 +57,7 @@ from transformers.testing_utils import (
get_tests_dir,
is_staging_test,
require_accelerate,
require_deepspeed,
require_intel_extension_for_pytorch,
require_optuna,
require_ray,
@@ -1551,6 +1552,51 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
with patch.object(sys, "argv", testargs):
run_glue.main()
@require_deepspeed
def test_auto_batch_size_with_resume_from_checkpoint_with_deepspeed(self):
train_dataset = RegressionDataset(length=128)
config = RegressionModelConfig(a=0, b=2)
model = RegressionRandomPreTrainedModel(config)
tmp_dir = self.get_auto_remove_tmp_dir()
class MockCudaOOMCallback(TrainerCallback):
def on_step_end(self, args, state, control, **kwargs):
# simulate OOM on the first step
if state.train_batch_size >= 16:
raise RuntimeError("CUDA out of memory.")
deepspeed = {
"zero_optimization": {
"stage": 1,
},
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
}
args = RegressionTrainingArguments(
tmp_dir,
do_train=True,
max_steps=2,
save_steps=1,
per_device_train_batch_size=16,
auto_find_batch_size=True,
deepspeed=deepspeed,
)
trainer = Trainer(model, args, train_dataset=train_dataset, callbacks=[MockCudaOOMCallback()])
trainer.train()
# After `auto_find_batch_size` is ran we should now be at 8
self.assertEqual(trainer._train_batch_size, 8)
# We can then make a new Trainer
trainer = Trainer(model, args, train_dataset=train_dataset)
# Check we are at 16 to start
self.assertEqual(trainer._train_batch_size, 16 * max(trainer.args.n_gpu, 1))
trainer.train(resume_from_checkpoint=True)
# We should be back to 8 again, picking up based upon the last ran Trainer
self.assertEqual(trainer._train_batch_size, 8)
def test_auto_batch_size_with_resume_from_checkpoint(self):
train_dataset = RegressionDataset(length=128)