From 5ae935d233dc1bab7ca45303a97d1c808a83afb5 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Thu, 22 Oct 2020 15:48:52 -0400 Subject: [PATCH] Reload checkpoint (#7984) * Fix checkpoint loading in Trainer * Fix typo --- src/transformers/trainer.py | 41 ++++++++++++++++++++-------- src/transformers/trainer_callback.py | 4 ++- src/transformers/trainer_pt_utils.py | 7 ++--- 3 files changed, 35 insertions(+), 17 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 21a452e2bd..400527a225 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -628,18 +628,7 @@ class Trainer: self.state.is_hyper_param_search = trial is not None # Check if saved optimizer or scheduler states exist - if ( - model_path is not None - and os.path.isfile(os.path.join(model_path, "optimizer.pt")) - and os.path.isfile(os.path.join(model_path, "scheduler.pt")) - ): - # Load in optimizer and scheduler states - self.optimizer.load_state_dict( - torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device) - ) - with warnings.catch_warnings(record=True) as caught_warnings: - self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt"))) - reissue_pt_warnings(caught_warnings) + self._load_optimizer_and_scheduler(model_path) # Mixed precision training with apex (torch < 1.6) model = self.model @@ -919,6 +908,34 @@ class Trainer: if self.is_world_process_zero(): self._rotate_checkpoints(use_mtime=True) + def _load_optimizer_and_scheduler(self, model_path): + """If optimizer and scheduler states exist, load them.""" + if ( + model_path is not None + and os.path.isfile(os.path.join(model_path, "optimizer.pt")) + and os.path.isfile(os.path.join(model_path, "scheduler.pt")) + ): + # Load in optimizer and scheduler states + if is_torch_tpu_available(): + # On TPU we have to take some extra precautions to properly load the states on the right device. + optimizer_state = torch.load(os.path.join(model_path, "optimizer.pt"), map_location="cpu") + with warnings.catch_warnings(record=True) as caught_warnings: + lr_scheduler_state = torch.load(os.path.join(model_path, "scheduler.pt"), map_location="cpu") + reissue_pt_warnings(caught_warnings) + + xm.send_cpu_data_to_device(optimizer_state, self.args.device) + xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device) + + self.optimizer.load_state_dict(optimizer_state) + self.lr_scheduler.load_state_dict(lr_scheduler_state) + else: + self.optimizer.load_state_dict( + torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device) + ) + with warnings.catch_warnings(record=True) as caught_warnings: + self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt"))) + reissue_pt_warnings(caught_warnings) + def hyperparameter_search( self, hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None, diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py index 2222362618..0a654caa4d 100644 --- a/src/transformers/trainer_callback.py +++ b/src/transformers/trainer_callback.py @@ -436,10 +436,12 @@ class ProgressCallback(TrainerCallback): def on_train_begin(self, args, state, control, **kwargs): if state.is_local_process_zero: self.training_bar = tqdm(total=state.max_steps) + self.current_step = 0 def on_step_end(self, args, state, control, **kwargs): if state.is_local_process_zero: - self.training_bar.update(1) + self.training_bar.update(state.global_step - self.current_step) + self.current_step = state.global_step def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs): if state.is_local_process_zero: diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index 45ff9c8fdf..aab0e162f8 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -23,6 +23,7 @@ from typing import List, Optional, Union import numpy as np import torch +from torch.optim.lr_scheduler import SAVE_STATE_WARNING from torch.utils.data.distributed import DistributedSampler from torch.utils.data.sampler import RandomSampler, Sampler @@ -33,8 +34,6 @@ from .utils import logging if is_torch_tpu_available(): import torch_xla.core.xla_model as xm -PT_LR_SCHEDULER_WARNING = "Please also save or load the state of the optimzer when saving or loading the scheduler." - logger = logging.get_logger(__name__) @@ -112,10 +111,10 @@ def distributed_broadcast_scalars( def reissue_pt_warnings(caught_warnings): - # Reissue warnings that are not the PT_LR_SCHEDULER_WARNING + # Reissue warnings that are not the SAVE_STATE_WARNING if len(caught_warnings) > 1: for w in caught_warnings: - if w.category != UserWarning or w.message != PT_LR_SCHEDULER_WARNING: + if w.category != UserWarning or w.message != SAVE_STATE_WARNING: warnings.warn(w.message, w.category)