Better support for resuming training (#8878)
This commit is contained in:
@@ -665,12 +665,12 @@ class Trainer:
|
||||
)
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", num_examples)
|
||||
logger.info(" Num Epochs = %d", num_train_epochs)
|
||||
logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
|
||||
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
|
||||
logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
|
||||
logger.info(" Total optimization steps = %d", max_steps)
|
||||
logger.info(f" Num examples = {num_examples}")
|
||||
logger.info(f" Num Epochs = {num_train_epochs}")
|
||||
logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
|
||||
logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
|
||||
logger.info(f" Total optimization steps = {max_steps}")
|
||||
|
||||
self.state.epoch = 0
|
||||
epochs_trained = 0
|
||||
@@ -680,13 +680,20 @@ class Trainer:
|
||||
if model_path and os.path.isfile(os.path.join(model_path, "trainer_state.json")):
|
||||
self.state = TrainerState.load_from_json(os.path.join(model_path, "trainer_state.json"))
|
||||
epochs_trained = self.state.global_step // num_update_steps_per_epoch
|
||||
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
|
||||
steps_trained_in_current_epoch *= self.args.gradient_accumulation_steps
|
||||
if not self.args.ignore_data_skip:
|
||||
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
|
||||
steps_trained_in_current_epoch *= self.args.gradient_accumulation_steps
|
||||
else:
|
||||
steps_trained_in_current_epoch = 0
|
||||
|
||||
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
|
||||
logger.info(" Continuing training from epoch %d", epochs_trained)
|
||||
logger.info(" Continuing training from global step %d", self.state.global_step)
|
||||
logger.info(" Will skip the first %d batches in the first epoch", steps_trained_in_current_epoch)
|
||||
logger.info(f" Continuing training from epoch {epochs_trained}")
|
||||
logger.info(f" Continuing training from global step {self.state.global_step}")
|
||||
if not self.args.ignore_data_skip:
|
||||
logger.info(
|
||||
f" Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} "
|
||||
"batches in the first epoch."
|
||||
)
|
||||
|
||||
# Update the references
|
||||
self.callback_handler.model = self.model
|
||||
@@ -712,6 +719,13 @@ class Trainer:
|
||||
|
||||
self.control = self.callback_handler.on_train_begin(self.args, self.state, self.control)
|
||||
|
||||
# Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
|
||||
if not self.args.ignore_data_skip:
|
||||
for epoch in range(epochs_trained):
|
||||
# We just need to begin an iteration to create the randomization of the sampler.
|
||||
for _ in train_dataloader:
|
||||
break
|
||||
|
||||
for epoch in range(epochs_trained, num_train_epochs):
|
||||
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
|
||||
train_dataloader.sampler.set_epoch(epoch)
|
||||
|
||||
Reference in New Issue
Block a user