From 94d8767ba3c036a37e8e1ea4e5ae3695b92eebe0 Mon Sep 17 00:00:00 2001 From: Tanmay Garg Date: Tue, 23 Feb 2021 02:03:00 +0530 Subject: [PATCH] Loading from last checkpoint functionality in Trainer.train (#10334) Enhance resume_from_checkpoint argument of Trainer.train to accept bool type. If True given, last saved checkpoint in self.args.output_dir will be loaded. (#10280) --- src/transformers/trainer.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 02837d3eeb..dc7a15d625 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -97,6 +97,7 @@ from .trainer_utils import ( TrainOutput, default_compute_objective, default_hp_space, + get_last_checkpoint, set_seed, speed_metrics, ) @@ -758,7 +759,7 @@ class Trainer: def train( self, - resume_from_checkpoint: Optional[str] = None, + resume_from_checkpoint: Optional[Union[str, bool]] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None, **kwargs, ): @@ -766,9 +767,11 @@ class Trainer: Main training entry point. Args: - resume_from_checkpoint (:obj:`str`, `optional`): - Local path to a saved checkpoint as saved by a previous instance of :class:`~transformers.Trainer`. If - present, training will resume from the model/optimizer/scheduler states loaded here. + resume_from_checkpoint (:obj:`str` or :obj:`bool`, `optional`): + If a :obj:`str`, local path to a saved checkpoint as saved by a previous instance of + :class:`~transformers.Trainer`. If a :obj:`bool` and equals `True`, load the last checkpoint in + `args.output_dir` as saved by a previous instance of :class:`~transformers.Trainer`. If present, + training will resume from the model/optimizer/scheduler states loaded here. trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`): The trial run or the hyperparameter dictionary for hyperparameter search. kwargs: @@ -803,6 +806,11 @@ class Trainer: self.optimizer, self.lr_scheduler = None, None # Load potential model checkpoint + if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint: + resume_from_checkpoint = get_last_checkpoint(self.args.output_dir) + if resume_from_checkpoint is None: + raise ValueError(f"No valid checkpoint found in output directory ({self.args.output_dir})") + if resume_from_checkpoint is not None and os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)): logger.info(f"Loading model from {resume_from_checkpoint}).") if isinstance(self.model, PreTrainedModel):