Loading from last checkpoint functionality in Trainer.train (#10334)
Enhance resume_from_checkpoint argument of Trainer.train to accept bool type. If True given, last saved checkpoint in self.args.output_dir will be loaded. (#10280)
This commit is contained in:
@@ -97,6 +97,7 @@ from .trainer_utils import (
|
|||||||
TrainOutput,
|
TrainOutput,
|
||||||
default_compute_objective,
|
default_compute_objective,
|
||||||
default_hp_space,
|
default_hp_space,
|
||||||
|
get_last_checkpoint,
|
||||||
set_seed,
|
set_seed,
|
||||||
speed_metrics,
|
speed_metrics,
|
||||||
)
|
)
|
||||||
@@ -758,7 +759,7 @@ class Trainer:
|
|||||||
|
|
||||||
def train(
|
def train(
|
||||||
self,
|
self,
|
||||||
resume_from_checkpoint: Optional[str] = None,
|
resume_from_checkpoint: Optional[Union[str, bool]] = None,
|
||||||
trial: Union["optuna.Trial", Dict[str, Any]] = None,
|
trial: Union["optuna.Trial", Dict[str, Any]] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@@ -766,9 +767,11 @@ class Trainer:
|
|||||||
Main training entry point.
|
Main training entry point.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
resume_from_checkpoint (:obj:`str`, `optional`):
|
resume_from_checkpoint (:obj:`str` or :obj:`bool`, `optional`):
|
||||||
Local path to a saved checkpoint as saved by a previous instance of :class:`~transformers.Trainer`. If
|
If a :obj:`str`, local path to a saved checkpoint as saved by a previous instance of
|
||||||
present, training will resume from the model/optimizer/scheduler states loaded here.
|
:class:`~transformers.Trainer`. If a :obj:`bool` and equals `True`, load the last checkpoint in
|
||||||
|
`args.output_dir` as saved by a previous instance of :class:`~transformers.Trainer`. If present,
|
||||||
|
training will resume from the model/optimizer/scheduler states loaded here.
|
||||||
trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
|
trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
|
||||||
The trial run or the hyperparameter dictionary for hyperparameter search.
|
The trial run or the hyperparameter dictionary for hyperparameter search.
|
||||||
kwargs:
|
kwargs:
|
||||||
@@ -803,6 +806,11 @@ class Trainer:
|
|||||||
self.optimizer, self.lr_scheduler = None, None
|
self.optimizer, self.lr_scheduler = None, None
|
||||||
|
|
||||||
# Load potential model checkpoint
|
# Load potential model checkpoint
|
||||||
|
if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
|
||||||
|
resume_from_checkpoint = get_last_checkpoint(self.args.output_dir)
|
||||||
|
if resume_from_checkpoint is None:
|
||||||
|
raise ValueError(f"No valid checkpoint found in output directory ({self.args.output_dir})")
|
||||||
|
|
||||||
if resume_from_checkpoint is not None and os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
|
if resume_from_checkpoint is not None and os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
|
||||||
logger.info(f"Loading model from {resume_from_checkpoint}).")
|
logger.info(f"Loading model from {resume_from_checkpoint}).")
|
||||||
if isinstance(self.model, PreTrainedModel):
|
if isinstance(self.model, PreTrainedModel):
|
||||||
|
|||||||
Reference in New Issue
Block a user