From a33168aa7894d5665eacf891ffdcb433cde0f38c Mon Sep 17 00:00:00 2001 From: Valentin Date: Tue, 16 Nov 2021 22:50:04 +0100 Subject: [PATCH] Avoid looping when data exhausted (#14413) * stop training when a finite IterableDataset is exhausted when using an iterable dataset num_epochs is set to sys.maxsize to make sure all data is consumed likewise we want to set max_steps high enough but still stop when all data is consumed (cherry picked from commit 6f0e1d6363153da9051e93acffe1cbab3a3f3b12) * fix typo flase -> false * add test for stopping training on exhausted finite iterable dataset * remove redundant gradient_accumulation_steps * run make style reformat training_args docstring --- src/transformers/trainer.py | 8 ++++++++ src/transformers/training_args.py | 3 ++- tests/test_trainer.py | 32 ++++++++++++++++++++++++++++++- 3 files changed, 41 insertions(+), 2 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index a39ce6bbfd..f954fe3ae0 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1287,6 +1287,7 @@ class Trainer: ) self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) + step = -1 for step, inputs in enumerate(epoch_iterator): # Skip past any already trained steps if resuming training @@ -1386,6 +1387,13 @@ class Trainer: if self.control.should_epoch_stop or self.control.should_training_stop: break + if step < 0: + logger.warning( + f"There seems to be not a single sample in your epoch_iterator, stopping training at step" + f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" + f" num_steps ({max_steps}) higher than the number of available samples." + ) + self.control.should_training_stop = True self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index fa6d8b71ce..644f83665e 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -141,7 +141,8 @@ class TrainingArguments: the last epoch before stopping training). max_steps (:obj:`int`, `optional`, defaults to -1): If set to a positive number, the total number of training steps to perform. Overrides - :obj:`num_train_epochs`. + :obj:`num_train_epochs`. In case of using a finite iterable dataset the training may stop before reaching + the set number of steps when all data is exhausted lr_scheduler_type (:obj:`str` or :class:`~transformers.SchedulerType`, `optional`, defaults to :obj:`"linear"`): The scheduler type to use. See the documentation of :class:`~transformers.SchedulerType` for all possible values. diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 0f1b6a5ff6..ab280e2bb5 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -172,6 +172,16 @@ if is_torch_available(): for i in range(len(self.dataset)): yield self.dataset[i] + class FiniteIterableDataset(SampleIterableDataset): + def __init__(self, a=2, b=3, length=64, seed=42, label_names=None): + super().__init__(a, b, length, seed, label_names) + self.current_sample = 0 + + def __iter__(self): + while self.current_sample < len(self.dataset): + yield self.dataset[self.current_sample] + self.current_sample += 1 + class RegressionModel(nn.Module): def __init__(self, a=0, b=0, double_output=False): super().__init__() @@ -856,7 +866,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): self.assertAlmostEqual(b, b1, delta=1e-8) # regression for this issue: https://github.com/huggingface/transformers/issues/12970 - def test_training_with_resume_from_checkpoint_flase(self): + def test_training_with_resume_from_checkpoint_false(self): train_dataset = RegressionDataset(length=128) eval_dataset = RegressionDataset() @@ -1058,6 +1068,26 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): self.assertIsInstance(loader, torch.utils.data.DataLoader) self.assertIsInstance(loader.sampler, torch.utils.data.dataloader._InfiniteConstantSampler) + def test_training_finite_iterable_dataset(self): + config = RegressionModelConfig() + model = RegressionPreTrainedModel(config) + + batch_size = 1 + num_samples = 10 + + available_steps = num_samples // batch_size + + data = FiniteIterableDataset(length=num_samples) + train_args = TrainingArguments( + ".", + max_steps=available_steps + 1, # set a higher number than actually available + per_device_train_batch_size=batch_size, + ) + trainer = Trainer(model, train_dataset=data, args=train_args) + with self.assertLogs("transformers.trainer", level="WARNING") as logs: + trainer.train() + self.assertIn(f"stopping training at step {available_steps}!", logs.output[0]) + def test_evaluation_iterable_dataset(self): config = RegressionModelConfig(a=1.5, b=2.5) model = RegressionPreTrainedModel(config)