From 7c10dd22ae4c94ad6afae3d22843e7203d3666de Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Tue, 1 Dec 2020 13:45:21 -0500 Subject: [PATCH] Better support for resuming training (#8878) --- src/transformers/trainer.py | 36 +++++++++++++++++++++---------- src/transformers/training_args.py | 10 +++++++++ tests/test_trainer.py | 30 ++++++++++++++++++++++++++ 3 files changed, 65 insertions(+), 11 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index f50b96e455..b678b064c9 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -665,12 +665,12 @@ class Trainer: ) logger.info("***** Running training *****") - logger.info(" Num examples = %d", num_examples) - logger.info(" Num Epochs = %d", num_train_epochs) - logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size) - logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size) - logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps) - logger.info(" Total optimization steps = %d", max_steps) + logger.info(f" Num examples = {num_examples}") + logger.info(f" Num Epochs = {num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size}") + logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_train_batch_size}") + logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {max_steps}") self.state.epoch = 0 epochs_trained = 0 @@ -680,13 +680,20 @@ class Trainer: if model_path and os.path.isfile(os.path.join(model_path, "trainer_state.json")): self.state = TrainerState.load_from_json(os.path.join(model_path, "trainer_state.json")) epochs_trained = self.state.global_step // num_update_steps_per_epoch - steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) - steps_trained_in_current_epoch *= self.args.gradient_accumulation_steps + if not self.args.ignore_data_skip: + steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) + steps_trained_in_current_epoch *= self.args.gradient_accumulation_steps + else: + steps_trained_in_current_epoch = 0 logger.info(" Continuing training from checkpoint, will skip to saved global_step") - logger.info(" Continuing training from epoch %d", epochs_trained) - logger.info(" Continuing training from global step %d", self.state.global_step) - logger.info(" Will skip the first %d batches in the first epoch", steps_trained_in_current_epoch) + logger.info(f" Continuing training from epoch {epochs_trained}") + logger.info(f" Continuing training from global step {self.state.global_step}") + if not self.args.ignore_data_skip: + logger.info( + f" Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} " + "batches in the first epoch." + ) # Update the references self.callback_handler.model = self.model @@ -712,6 +719,13 @@ class Trainer: self.control = self.callback_handler.on_train_begin(self.args, self.state, self.control) + # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point. 
+ if not self.args.ignore_data_skip: + for epoch in range(epochs_trained): + # We just need to begin an iteration to create the randomization of the sampler. + for _ in train_dataloader: + break + for epoch in range(epochs_trained, num_train_epochs): if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler): train_dataloader.sampler.set_epoch(epoch) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 78990cc879..7142dcb3c2 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -189,6 +189,10 @@ class TrainingArguments: model_parallel (:obj:`bool`, `optional`, defaults to :obj:`False`): If there are more than one devices, whether to use model parallelism to distribute the model's modules across devices or not. + ignore_data_skip (:obj:`bool`, `optional`, defaults to :obj:`False`): + When resuming training, whether or not to skip the epochs and batches to get the data loading at the same + stage as in the previous training. If set to :obj:`True`, the training will begin faster (as that skipping + step can take a long time) but will not yield the same results as the interrupted training would have. """ output_dir: str = field( @@ -350,6 +354,12 @@ class TrainingArguments: greater_is_better: Optional[bool] = field( default=None, metadata={"help": "Whether the `metric_for_best_model` should be maximized or not."} ) + ignore_data_skip: bool = field( + default=False, + metadata={ + "help": "When resuming training, whether or not to skip the first epochs and batches to get to the same training data." 
+ }, + ) def __post_init__(self): if self.disable_tqdm is None: diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 3a5916d19a..e6fd44c37c 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -554,6 +554,20 @@ class TrainerIntegrationTest(unittest.TestCase): self.assertEqual(b, b1) self.assertEqual(state, state1) + # Now check with a later checkpoint that it also works when we span over one epoch + checkpoint = os.path.join(tmpdir, "checkpoint-15") + + # Reinitialize trainer and load model + model = RegressionPreTrainedModel.from_pretrained(checkpoint) + trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset) + + trainer.train(model_path=checkpoint) + (a1, b1) = trainer.model.a.item(), trainer.model.b.item() + state1 = dataclasses.asdict(trainer.state) + self.assertEqual(a, a1) + self.assertEqual(b, b1) + self.assertEqual(state, state1) + # With a regular model that is not a PreTrainedModel with tempfile.TemporaryDirectory() as tmpdir: trainer = get_regression_trainer( @@ -578,6 +592,22 @@ class TrainerIntegrationTest(unittest.TestCase): self.assertEqual(b, b1) self.assertEqual(state, state1) + # Now check with a later checkpoint that it also works when we span over one epoch + checkpoint = os.path.join(tmpdir, "checkpoint-15") + + # Reinitialize trainer and load model + model = RegressionModel() + state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME)) + model.load_state_dict(state_dict) + trainer = Trainer(model, trainer.args, train_dataset=trainer.train_dataset) + + trainer.train(model_path=checkpoint) + (a1, b1) = trainer.model.a.item(), trainer.model.b.item() + state1 = dataclasses.asdict(trainer.state) + self.assertEqual(a, a1) + self.assertEqual(b, b1) + self.assertEqual(state, state1) + def test_resume_training_with_gradient_accumulation(self): if torch.cuda.device_count() > 2: # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of