Loading from last checkpoint functionality in Trainer.train (#10334)

Enhance resume_from_checkpoint argument of Trainer.train to accept
bool type. If True given, last saved checkpoint in self.args.output_dir
will be loaded. (#10280)
This commit is contained in:
Tanmay Garg
2021-02-23 02:03:00 +05:30
committed by GitHub
parent eab0afc19c
commit 94d8767ba3

View File

@@ -97,6 +97,7 @@ from .trainer_utils import (
TrainOutput, TrainOutput,
default_compute_objective, default_compute_objective,
default_hp_space, default_hp_space,
get_last_checkpoint,
set_seed, set_seed,
speed_metrics, speed_metrics,
) )
@@ -758,7 +759,7 @@ class Trainer:
def train( def train(
self, self,
resume_from_checkpoint: Optional[str] = None, resume_from_checkpoint: Optional[Union[str, bool]] = None,
trial: Union["optuna.Trial", Dict[str, Any]] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None,
**kwargs, **kwargs,
): ):
@@ -766,9 +767,11 @@ class Trainer:
Main training entry point. Main training entry point.
Args: Args:
resume_from_checkpoint (:obj:`str`, `optional`): resume_from_checkpoint (:obj:`str` or :obj:`bool`, `optional`):
Local path to a saved checkpoint as saved by a previous instance of :class:`~transformers.Trainer`. If If a :obj:`str`, local path to a saved checkpoint as saved by a previous instance of
present, training will resume from the model/optimizer/scheduler states loaded here. :class:`~transformers.Trainer`. If a :obj:`bool` and equals `True`, load the last checkpoint in
`args.output_dir` as saved by a previous instance of :class:`~transformers.Trainer`. If present,
training will resume from the model/optimizer/scheduler states loaded here.
trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`): trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
The trial run or the hyperparameter dictionary for hyperparameter search. The trial run or the hyperparameter dictionary for hyperparameter search.
kwargs: kwargs:
@@ -803,6 +806,11 @@ class Trainer:
self.optimizer, self.lr_scheduler = None, None self.optimizer, self.lr_scheduler = None, None
# Load potential model checkpoint # Load potential model checkpoint
if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
resume_from_checkpoint = get_last_checkpoint(self.args.output_dir)
if resume_from_checkpoint is None:
raise ValueError(f"No valid checkpoint found in output directory ({self.args.output_dir})")
if resume_from_checkpoint is not None and os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)): if resume_from_checkpoint is not None and os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
logger.info(f"Loading model from {resume_from_checkpoint}).") logger.info(f"Loading model from {resume_from_checkpoint}).")
if isinstance(self.model, PreTrainedModel): if isinstance(self.model, PreTrainedModel):