Allow resume_from_checkpoint to handle auto_find_batch_size (#27568)
* Fuffill request * Add test * Better test * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Better test * Better test * MOre comments --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -38,6 +38,7 @@ from transformers import (
|
||||
AutoTokenizer,
|
||||
IntervalStrategy,
|
||||
PretrainedConfig,
|
||||
TrainerCallback,
|
||||
TrainingArguments,
|
||||
get_polynomial_decay_schedule_with_warmup,
|
||||
is_torch_available,
|
||||
@@ -1546,6 +1547,41 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
with patch.object(sys, "argv", testargs):
|
||||
run_glue.main()
|
||||
|
||||
def test_auto_batch_size_with_resume_from_checkpoint(self):
|
||||
train_dataset = RegressionDataset(length=128)
|
||||
|
||||
config = RegressionModelConfig(a=0, b=2)
|
||||
model = RegressionRandomPreTrainedModel(config)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
|
||||
class MockCudaOOMCallback(TrainerCallback):
|
||||
def on_step_end(self, args, state, control, **kwargs):
|
||||
# simulate OOM on the first step
|
||||
if state.train_batch_size == 16:
|
||||
raise RuntimeError("CUDA out of memory.")
|
||||
|
||||
args = RegressionTrainingArguments(
|
||||
tmp_dir,
|
||||
do_train=True,
|
||||
max_steps=2,
|
||||
save_steps=1,
|
||||
per_device_train_batch_size=16,
|
||||
auto_find_batch_size=True,
|
||||
)
|
||||
trainer = Trainer(model, args, train_dataset=train_dataset, callbacks=[MockCudaOOMCallback()])
|
||||
trainer.train()
|
||||
# After `auto_find_batch_size` is ran we should now be at 8
|
||||
self.assertEqual(trainer._train_batch_size, 8)
|
||||
|
||||
# We can then make a new Trainer
|
||||
trainer = Trainer(model, args, train_dataset=train_dataset)
|
||||
# Check we are at 16 to start
|
||||
self.assertEqual(trainer._train_batch_size, 16)
|
||||
trainer.train(resume_from_checkpoint=True)
|
||||
# We should be back to 8 again, picking up based upon the last ran Trainer
|
||||
self.assertEqual(trainer._train_batch_size, 8)
|
||||
|
||||
# regression for this issue: https://github.com/huggingface/transformers/issues/12970
|
||||
def test_training_with_resume_from_checkpoint_false(self):
|
||||
train_dataset = RegressionDataset(length=128)
|
||||
|
||||
Reference in New Issue
Block a user