Reload checkpoint (#7984)
* Fix checkpoint loading in Trainer * Fix typo
This commit is contained in:
@@ -628,18 +628,7 @@ class Trainer:
|
||||
self.state.is_hyper_param_search = trial is not None
|
||||
|
||||
# Check if saved optimizer or scheduler states exist
|
||||
if (
|
||||
model_path is not None
|
||||
and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
|
||||
and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
|
||||
):
|
||||
# Load in optimizer and scheduler states
|
||||
self.optimizer.load_state_dict(
|
||||
torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
|
||||
)
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
self._load_optimizer_and_scheduler(model_path)
|
||||
|
||||
# Mixed precision training with apex (torch < 1.6)
|
||||
model = self.model
|
||||
@@ -919,6 +908,34 @@ class Trainer:
|
||||
if self.is_world_process_zero():
|
||||
self._rotate_checkpoints(use_mtime=True)
|
||||
|
||||
def _load_optimizer_and_scheduler(self, model_path):
|
||||
"""If optimizer and scheduler states exist, load them."""
|
||||
if (
|
||||
model_path is not None
|
||||
and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
|
||||
and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
|
||||
):
|
||||
# Load in optimizer and scheduler states
|
||||
if is_torch_tpu_available():
|
||||
# On TPU we have to take some extra precautions to properly load the states on the right device.
|
||||
optimizer_state = torch.load(os.path.join(model_path, "optimizer.pt"), map_location="cpu")
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
lr_scheduler_state = torch.load(os.path.join(model_path, "scheduler.pt"), map_location="cpu")
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
|
||||
xm.send_cpu_data_to_device(optimizer_state, self.args.device)
|
||||
xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)
|
||||
|
||||
self.optimizer.load_state_dict(optimizer_state)
|
||||
self.lr_scheduler.load_state_dict(lr_scheduler_state)
|
||||
else:
|
||||
self.optimizer.load_state_dict(
|
||||
torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
|
||||
)
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
|
||||
def hyperparameter_search(
|
||||
self,
|
||||
hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
|
||||
|
||||
Reference in New Issue
Block a user