Allow resume_from_checkpoint to handle auto_find_batch_size (#27568)

* Fulfill request

* Add test

* Better test

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Better test

* Better test

* More comments

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Zach Mueller
2023-12-08 11:51:02 -05:00
committed by GitHub
parent aa7ab98e72
commit 6757ed28ce
3 changed files with 47 additions and 0 deletions

View File

@@ -1507,6 +1507,10 @@ class Trainer:
and not self.is_fsdp_enabled and not self.is_fsdp_enabled
): ):
self._load_from_checkpoint(resume_from_checkpoint) self._load_from_checkpoint(resume_from_checkpoint)
# In case of repeating the find_executable_batch_size, set `self._train_batch_size` properly
state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
if state.train_batch_size is not None:
self._train_batch_size = state.train_batch_size
# If model was re-initialized, put it on the right device and update self.model_wrapped # If model was re-initialized, put it on the right device and update self.model_wrapped
if model_reloaded: if model_reloaded:
@@ -1542,6 +1546,8 @@ class Trainer:
): ):
self.accelerator.free_memory() self.accelerator.free_memory()
self._train_batch_size = batch_size self._train_batch_size = batch_size
if self.args.auto_find_batch_size:
self.state.train_batch_size = self._train_batch_size
logger.debug(f"Currently training with a batch size of: {self._train_batch_size}") logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
# Data loader and number of training steps # Data loader and number of training steps
train_dataloader = self.get_train_dataloader() train_dataloader = self.get_train_dataloader()
@@ -1618,6 +1624,7 @@ class Trainer:
self.state = TrainerState() self.state = TrainerState()
self.state.is_hyper_param_search = trial is not None self.state.is_hyper_param_search = trial is not None
self.state.train_batch_size = self._train_batch_size
# Compute absolute values for logging, eval, and save if given as ratio # Compute absolute values for logging, eval, and save if given as ratio
if args.logging_steps is not None: if args.logging_steps is not None:

View File

@@ -59,6 +59,9 @@ class TrainerState:
Run an evaluation every X steps. Run an evaluation every X steps.
save_steps (`int`, *optional*, defaults to 500): save_steps (`int`, *optional*, defaults to 500):
Save checkpoint every X updates steps. Save checkpoint every X updates steps.
train_batch_size (`int`, *optional*):
The batch size for the training dataloader. Only needed when
`auto_find_batch_size` has been used.
num_input_tokens_seen (`int`, *optional*, defaults to 0): num_input_tokens_seen (`int`, *optional*, defaults to 0):
The number of tokens seen during training (number of input tokens, not the number of prediction tokens). The number of tokens seen during training (number of input tokens, not the number of prediction tokens).
total_flos (`float`, *optional*, defaults to 0): total_flos (`float`, *optional*, defaults to 0):
@@ -88,6 +91,7 @@ class TrainerState:
logging_steps: int = 500 logging_steps: int = 500
eval_steps: int = 500 eval_steps: int = 500
save_steps: int = 500 save_steps: int = 500
train_batch_size: int = None
num_train_epochs: int = 0 num_train_epochs: int = 0
num_input_tokens_seen: int = 0 num_input_tokens_seen: int = 0
total_flos: float = 0 total_flos: float = 0

View File

@@ -38,6 +38,7 @@ from transformers import (
AutoTokenizer, AutoTokenizer,
IntervalStrategy, IntervalStrategy,
PretrainedConfig, PretrainedConfig,
TrainerCallback,
TrainingArguments, TrainingArguments,
get_polynomial_decay_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup,
is_torch_available, is_torch_available,
@@ -1546,6 +1547,41 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
run_glue.main() run_glue.main()
def test_auto_batch_size_with_resume_from_checkpoint(self):
    """Check that resuming from a checkpoint restores the auto-found batch size.

    A first `Trainer` is run with `auto_find_batch_size=True` and a callback that
    raises a fake CUDA OOM error while the batch size is still too large, so the
    run only completes at a reduced batch size. A second, fresh `Trainer` is then
    resumed from the saved checkpoint and must pick up that reduced batch size
    instead of the one configured in the training arguments.
    """
    train_dataset = RegressionDataset(length=128)

    config = RegressionModelConfig(a=0, b=2)
    model = RegressionRandomPreTrainedModel(config)

    tmp_dir = self.get_auto_remove_tmp_dir()

    class MockCudaOOMCallback(TrainerCallback):
        def on_step_end(self, args, state, control, **kwargs):
            # Simulate OOM for every batch size that is "too large". Using `>=`
            # rather than `==` keeps the mock effective when the starting batch
            # size is a multiple of 16 (e.g. scaled by the number of devices).
            if state.train_batch_size >= 16:
                raise RuntimeError("CUDA out of memory.")

    args = RegressionTrainingArguments(
        tmp_dir,
        do_train=True,
        max_steps=2,
        save_steps=1,
        per_device_train_batch_size=16,
        auto_find_batch_size=True,
    )
    trainer = Trainer(model, args, train_dataset=train_dataset, callbacks=[MockCudaOOMCallback()])
    trainer.train()
    # After `auto_find_batch_size` has run we should now be at 8
    # NOTE(review): assumes the batch-size search halves on each OOM so it lands
    # exactly on 8; with a device count that is not a power of two it may stop
    # at a different value below 16 -- confirm on such setups.
    self.assertEqual(trainer._train_batch_size, 8)

    # We can then make a new Trainer
    trainer = Trainer(model, args, train_dataset=train_dataset)
    # Check we start from the configured size; presumably the effective train
    # batch size is the per-device size times the device count (at least 1).
    self.assertEqual(trainer._train_batch_size, 16 * max(trainer.args.n_gpu, 1))
    trainer.train(resume_from_checkpoint=True)
    # We should be back to 8 again, picking up from the previously run Trainer
    self.assertEqual(trainer._train_batch_size, 8)
# regression for this issue: https://github.com/huggingface/transformers/issues/12970 # regression for this issue: https://github.com/huggingface/transformers/issues/12970
def test_training_with_resume_from_checkpoint_false(self): def test_training_with_resume_from_checkpoint_false(self):
train_dataset = RegressionDataset(length=128) train_dataset = RegressionDataset(length=128)