Allow resume_from_checkpoint to handle auto_find_batch_size (#27568)
* Fulfill request * Add test * Better test * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Better test * Better test * More comments --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -1507,6 +1507,10 @@ class Trainer:
|
|||||||
and not self.is_fsdp_enabled
|
and not self.is_fsdp_enabled
|
||||||
):
|
):
|
||||||
self._load_from_checkpoint(resume_from_checkpoint)
|
self._load_from_checkpoint(resume_from_checkpoint)
|
||||||
|
# In case of repeating the find_executable_batch_size, set `self._train_batch_size` properly
|
||||||
|
state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
|
||||||
|
if state.train_batch_size is not None:
|
||||||
|
self._train_batch_size = state.train_batch_size
|
||||||
|
|
||||||
# If model was re-initialized, put it on the right device and update self.model_wrapped
|
# If model was re-initialized, put it on the right device and update self.model_wrapped
|
||||||
if model_reloaded:
|
if model_reloaded:
|
||||||
@@ -1542,6 +1546,8 @@ class Trainer:
|
|||||||
):
|
):
|
||||||
self.accelerator.free_memory()
|
self.accelerator.free_memory()
|
||||||
self._train_batch_size = batch_size
|
self._train_batch_size = batch_size
|
||||||
|
if self.args.auto_find_batch_size:
|
||||||
|
self.state.train_batch_size = self._train_batch_size
|
||||||
logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
|
logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
|
||||||
# Data loader and number of training steps
|
# Data loader and number of training steps
|
||||||
train_dataloader = self.get_train_dataloader()
|
train_dataloader = self.get_train_dataloader()
|
||||||
@@ -1618,6 +1624,7 @@ class Trainer:
|
|||||||
|
|
||||||
self.state = TrainerState()
|
self.state = TrainerState()
|
||||||
self.state.is_hyper_param_search = trial is not None
|
self.state.is_hyper_param_search = trial is not None
|
||||||
|
self.state.train_batch_size = self._train_batch_size
|
||||||
|
|
||||||
# Compute absolute values for logging, eval, and save if given as ratio
|
# Compute absolute values for logging, eval, and save if given as ratio
|
||||||
if args.logging_steps is not None:
|
if args.logging_steps is not None:
|
||||||
|
|||||||
@@ -59,6 +59,9 @@ class TrainerState:
|
|||||||
Run an evaluation every X steps.
|
Run an evaluation every X steps.
|
||||||
save_steps (`int`, *optional*, defaults to 500):
|
save_steps (`int`, *optional*, defaults to 500):
|
||||||
Save checkpoint every X updates steps.
|
Save checkpoint every X updates steps.
|
||||||
|
train_batch_size (`int`, *optional*):
|
||||||
|
The batch size for the training dataloader. Only needed when
|
||||||
|
`auto_find_batch_size` has been used.
|
||||||
num_input_tokens_seen (`int`, *optional*, defaults to 0):
|
num_input_tokens_seen (`int`, *optional*, defaults to 0):
|
||||||
The number of tokens seen during training (number of input tokens, not the number of prediction tokens).
|
The number of tokens seen during training (number of input tokens, not the number of prediction tokens).
|
||||||
total_flos (`float`, *optional*, defaults to 0):
|
total_flos (`float`, *optional*, defaults to 0):
|
||||||
@@ -88,6 +91,7 @@ class TrainerState:
|
|||||||
logging_steps: int = 500
|
logging_steps: int = 500
|
||||||
eval_steps: int = 500
|
eval_steps: int = 500
|
||||||
save_steps: int = 500
|
save_steps: int = 500
|
||||||
|
train_batch_size: int = None
|
||||||
num_train_epochs: int = 0
|
num_train_epochs: int = 0
|
||||||
num_input_tokens_seen: int = 0
|
num_input_tokens_seen: int = 0
|
||||||
total_flos: float = 0
|
total_flos: float = 0
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ from transformers import (
|
|||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
IntervalStrategy,
|
IntervalStrategy,
|
||||||
PretrainedConfig,
|
PretrainedConfig,
|
||||||
|
TrainerCallback,
|
||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
get_polynomial_decay_schedule_with_warmup,
|
get_polynomial_decay_schedule_with_warmup,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
@@ -1546,6 +1547,41 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
run_glue.main()
|
run_glue.main()
|
||||||
|
|
||||||
|
def test_auto_batch_size_with_resume_from_checkpoint(self):
|
||||||
|
train_dataset = RegressionDataset(length=128)
|
||||||
|
|
||||||
|
config = RegressionModelConfig(a=0, b=2)
|
||||||
|
model = RegressionRandomPreTrainedModel(config)
|
||||||
|
|
||||||
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
|
||||||
|
class MockCudaOOMCallback(TrainerCallback):
|
||||||
|
def on_step_end(self, args, state, control, **kwargs):
|
||||||
|
# simulate OOM on the first step
|
||||||
|
if state.train_batch_size == 16:
|
||||||
|
raise RuntimeError("CUDA out of memory.")
|
||||||
|
|
||||||
|
args = RegressionTrainingArguments(
|
||||||
|
tmp_dir,
|
||||||
|
do_train=True,
|
||||||
|
max_steps=2,
|
||||||
|
save_steps=1,
|
||||||
|
per_device_train_batch_size=16,
|
||||||
|
auto_find_batch_size=True,
|
||||||
|
)
|
||||||
|
trainer = Trainer(model, args, train_dataset=train_dataset, callbacks=[MockCudaOOMCallback()])
|
||||||
|
trainer.train()
|
||||||
|
# After `auto_find_batch_size` has run, we should now be at 8
|
||||||
|
self.assertEqual(trainer._train_batch_size, 8)
|
||||||
|
|
||||||
|
# We can then make a new Trainer
|
||||||
|
trainer = Trainer(model, args, train_dataset=train_dataset)
|
||||||
|
# Check we are at 16 to start
|
||||||
|
self.assertEqual(trainer._train_batch_size, 16)
|
||||||
|
trainer.train(resume_from_checkpoint=True)
|
||||||
|
# We should be back to 8 again, picking up from the most recently run Trainer
|
||||||
|
self.assertEqual(trainer._train_batch_size, 8)
|
||||||
|
|
||||||
# regression for this issue: https://github.com/huggingface/transformers/issues/12970
|
# regression for this issue: https://github.com/huggingface/transformers/issues/12970
|
||||||
def test_training_with_resume_from_checkpoint_false(self):
|
def test_training_with_resume_from_checkpoint_false(self):
|
||||||
train_dataset = RegressionDataset(length=128)
|
train_dataset = RegressionDataset(length=128)
|
||||||
|
|||||||
Reference in New Issue
Block a user