Reload checkpoint (#7984)

* Fix checkpoint loading in Trainer

* Fix typo
This commit is contained in:
Sylvain Gugger
2020-10-22 15:48:52 -04:00
committed by GitHub
parent 467573ddde
commit 5ae935d233
3 changed files with 35 additions and 17 deletions

View File

@@ -628,18 +628,7 @@ class Trainer:
self.state.is_hyper_param_search = trial is not None
# Check if saved optimizer or scheduler states exist
if (
model_path is not None
and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
):
# Load in optimizer and scheduler states
self.optimizer.load_state_dict(
torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
)
with warnings.catch_warnings(record=True) as caught_warnings:
self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
reissue_pt_warnings(caught_warnings)
self._load_optimizer_and_scheduler(model_path)
# Mixed precision training with apex (torch < 1.6)
model = self.model
@@ -919,6 +908,34 @@ class Trainer:
if self.is_world_process_zero():
self._rotate_checkpoints(use_mtime=True)
def _load_optimizer_and_scheduler(self, model_path):
"""If optimizer and scheduler states exist, load them."""
if (
model_path is not None
and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
):
# Load in optimizer and scheduler states
if is_torch_tpu_available():
# On TPU we have to take some extra precautions to properly load the states on the right device.
optimizer_state = torch.load(os.path.join(model_path, "optimizer.pt"), map_location="cpu")
with warnings.catch_warnings(record=True) as caught_warnings:
lr_scheduler_state = torch.load(os.path.join(model_path, "scheduler.pt"), map_location="cpu")
reissue_pt_warnings(caught_warnings)
xm.send_cpu_data_to_device(optimizer_state, self.args.device)
xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)
self.optimizer.load_state_dict(optimizer_state)
self.lr_scheduler.load_state_dict(lr_scheduler_state)
else:
self.optimizer.load_state_dict(
torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
)
with warnings.catch_warnings(record=True) as caught_warnings:
self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
reissue_pt_warnings(caught_warnings)
def hyperparameter_search(
self,
hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,