Fix loading the best model on the last stage of training (#11718)
This commit is contained in:
@@ -1059,18 +1059,7 @@ class Trainer:
|
|||||||
# We load the model state dict on the CPU to avoid an OOM error.
|
# We load the model state dict on the CPU to avoid an OOM error.
|
||||||
state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
|
state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
|
||||||
# If the model is on the GPU, it still works!
|
# If the model is on the GPU, it still works!
|
||||||
load_result = self.model.load_state_dict(state_dict, strict=False)
|
self._load_state_dict_in_model(state_dict)
|
||||||
if len(load_result.missing_keys) != 0:
|
|
||||||
if load_result.missing_keys == self.model._keys_to_ignore_on_save:
|
|
||||||
self.model.tie_weights()
|
|
||||||
else:
|
|
||||||
logger.warn(
|
|
||||||
f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}."
|
|
||||||
)
|
|
||||||
if len(load_result.unexpected_keys) != 0:
|
|
||||||
logger.warn(
|
|
||||||
f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
|
|
||||||
)
|
|
||||||
|
|
||||||
# If model was re-initialized, put it on the right device and update self.model_wrapped
|
# If model was re-initialized, put it on the right device and update self.model_wrapped
|
||||||
if model_reloaded:
|
if model_reloaded:
|
||||||
@@ -1363,7 +1352,7 @@ class Trainer:
|
|||||||
# We load the model state dict on the CPU to avoid an OOM error.
|
# We load the model state dict on the CPU to avoid an OOM error.
|
||||||
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME), map_location="cpu")
|
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME), map_location="cpu")
|
||||||
# If the model is on the GPU, it still works!
|
# If the model is on the GPU, it still works!
|
||||||
self.model.load_state_dict(state_dict)
|
self._load_state_dict_in_model(state_dict)
|
||||||
|
|
||||||
if self.deepspeed:
|
if self.deepspeed:
|
||||||
self.deepspeed.load_checkpoint(
|
self.deepspeed.load_checkpoint(
|
||||||
@@ -1385,6 +1374,17 @@ class Trainer:
|
|||||||
|
|
||||||
return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step, metrics)
|
return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step, metrics)
|
||||||
|
|
||||||
|
def _load_state_dict_in_model(self, state_dict):
|
||||||
|
load_result = self.model.load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
|
if len(load_result.missing_keys) != 0:
|
||||||
|
if set(load_result.missing_keys) == set(self.model._keys_to_ignore_on_save):
|
||||||
|
self.model.tie_weights()
|
||||||
|
else:
|
||||||
|
logger.warn(f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.")
|
||||||
|
if len(load_result.unexpected_keys) != 0:
|
||||||
|
logger.warn(f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}.")
|
||||||
|
|
||||||
def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
|
def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
|
||||||
if self.control.should_log:
|
if self.control.should_log:
|
||||||
logs: Dict[str, float] = {}
|
logs: Dict[str, float] = {}
|
||||||
|
|||||||
@@ -180,7 +180,8 @@ class ModelTesterMixin:
|
|||||||
# Test we can load the state dict in the model, necessary for the checkpointing API in Trainer.
|
# Test we can load the state dict in the model, necessary for the checkpointing API in Trainer.
|
||||||
load_result = model.load_state_dict(state_dict_saved, strict=False)
|
load_result = model.load_state_dict(state_dict_saved, strict=False)
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
len(load_result.missing_keys) == 0 or load_result.missing_keys == model._keys_to_ignore_on_save
|
len(load_result.missing_keys) == 0
|
||||||
|
or set(load_result.missing_keys) == set(model._keys_to_ignore_on_save)
|
||||||
)
|
)
|
||||||
self.assertTrue(len(load_result.unexpected_keys) == 0)
|
self.assertTrue(len(load_result.unexpected_keys) == 0)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user