Add automatic best model loading to Trainer (#7431)

* Add automatic best model loading to Trainer

* Some small fixes

* Formatting
This commit is contained in:
Sylvain Gugger
2020-09-29 10:41:18 -04:00
committed by GitHub
parent 1fc4de69ed
commit 52e8392b7e
4 changed files with 313 additions and 47 deletions

View File

@@ -20,7 +20,7 @@ from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
from tqdm.auto import tqdm, trange
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .file_utils import is_datasets_available, is_torch_tpu_available
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_torch_tpu_available
from .integrations import (
default_hp_search_backend,
is_comet_available,
@@ -42,6 +42,7 @@ from .trainer_utils import (
EvaluationStrategy,
HPSearchBackend,
PredictionOutput,
TrainerState,
TrainOutput,
default_compute_objective,
default_hp_space,
@@ -642,6 +643,7 @@ class Trainer:
self.args.max_steps = t_total
self.create_optimizer_and_scheduler(num_training_steps=t_total)
self.state = TrainerState()
# Check if saved optimizer or scheduler states exist
if (
@@ -657,6 +659,10 @@ class Trainer:
self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
reissue_pt_warnings(caught_warnings)
# Check if a saved Trainer state exist
if model_path is not None and os.path.isfile(os.path.join(model_path, "trainer_state.json")):
self.state = TrainerState.load_from_json(os.path.join(model_path, "trainer_state.json"))
model = self.model
if self.args.fp16 and _use_apex:
if not is_apex_available():
@@ -803,44 +809,15 @@ class Trainer:
):
metrics = self.evaluate()
self._report_to_hp_search(trial, epoch, metrics)
if self.args.load_best_model_at_end:
self._save_training(model, trial, metrics=metrics)
if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
# In all cases (even distributed/parallel), self.model is always a reference
# to the model we want to save.
if hasattr(model, "module"):
assert (
model.module is self.model
), f"Module {model.module} should be a reference to self.model"
else:
assert model is self.model, f"Model {model} should be a reference to self.model"
# Save model checkpoint
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}"
if self.hp_search_backend is not None and trial is not None:
run_id = (
trial.number
if self.hp_search_backend == HPSearchBackend.OPTUNA
else tune.get_trial_id()
)
checkpoint_folder += f"-run-{run_id}"
output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
self.store_flos()
self.save_model(output_dir)
if self.is_world_process_zero():
self._rotate_checkpoints(use_mtime=True)
if is_torch_tpu_available():
xm.rendezvous("saving_optimizer_states")
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
with warnings.catch_warnings(record=True) as caught_warnings:
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
reissue_pt_warnings(caught_warnings)
elif self.is_world_process_zero():
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
reissue_pt_warnings(caught_warnings)
if (
not self.args.load_best_model_at_end
and self.args.save_steps > 0
and self.global_step % self.args.save_steps == 0
):
self._save_training(model, trial)
epoch_pbar.update(1)
if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
@@ -851,6 +828,8 @@ class Trainer:
if self.args.evaluation_strategy == EvaluationStrategy.EPOCH:
metrics = self.evaluate()
self._report_to_hp_search(trial, epoch, metrics)
if self.args.load_best_model_at_end:
self._save_training(model, trial, metrics=metrics)
if self.args.tpu_metrics_debug or self.args.debug:
if is_torch_tpu_available():
@@ -872,8 +851,73 @@ class Trainer:
delattr(self, "_past")
logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
if self.args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
logger.info(
f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
)
if isinstance(model, PreTrainedModel):
self.model = model.from_pretrained(self.state.best_model_checkpoint)
self.model = self.model.to(self.args.device)
else:
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
self.model.load_state_dict(state_dict)
return TrainOutput(self.global_step, tr_loss.item() / self.global_step)
def _save_training(self, model, trial, metrics=None):
# In all cases (even distributed/parallel), self.model is always a reference
# to the model we want to save.
if hasattr(model, "module"):
assert model.module is self.model, f"Module {model.module} should be a reference to self.model"
else:
assert model is self.model, f"Model {model} should be a reference to self.model"
# Save model checkpoint
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}"
if self.hp_search_backend is not None and trial is not None:
run_id = trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id()
checkpoint_folder += f"-run-{run_id}"
output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
self.store_flos()
self.save_model(output_dir)
# Save optimizer and scheduler
if is_torch_tpu_available():
xm.rendezvous("saving_optimizer_states")
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
with warnings.catch_warnings(record=True) as caught_warnings:
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
reissue_pt_warnings(caught_warnings)
elif self.is_world_process_zero():
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
reissue_pt_warnings(caught_warnings)
# Determine the new best metric / best model checkpoint
if metrics is not None:
metric_to_check = self.args.metric_for_best_model
if not metric_to_check.startswith("eval_"):
metric_to_check = f"eval_{metric_to_check}"
metric_value = metrics[metric_to_check]
operator = np.greater if self.args.greater_is_better else np.less
if (
self.state.best_metric is None
or self.state.best_model_checkpoint is None
or operator(metric_value, self.state.best_metric)
):
self.state.best_metric = metric_value
self.state.best_model_checkpoint = output_dir
# Save the Trainer state
if self.is_world_process_zero():
self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
# Maybe delete some older checkpoints.
if self.is_world_process_zero():
self._rotate_checkpoints(use_mtime=True)
def hyperparameter_search(
self,
hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
@@ -1164,11 +1208,13 @@ class Trainer:
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, PreTrainedModel):
raise ValueError("Trainer.model appears to not be a PreTrainedModel")
xm.rendezvous("saving_checkpoint")
self.model.save_pretrained(output_dir)
if not isinstance(self.model, PreTrainedModel):
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
state_dict = self.model.state_dict()
xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
self.model.save_pretrained(output_dir)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
@@ -1179,8 +1225,11 @@ class Trainer:
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, PreTrainedModel):
raise ValueError("Trainer.model appears to not be a PreTrainedModel")
self.model.save_pretrained(output_dir)
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
state_dict = self.model.state_dict()
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
self.model.save_pretrained(output_dir)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
@@ -1215,6 +1264,13 @@ class Trainer:
checkpoints_sorted = sorted(ordering_and_checkpoint_path)
checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
# Make sure we don't delete the best model.
if self.state.best_model_checkpoint is not None:
best_model_index = checkpoints_sorted.index(self.state.best_model_checkpoint)
checkpoints_sorted[best_model_index], checkpoints_sorted[best_model_index][-1] = (
checkpoints_sorted[-1],
checkpoints_sorted[best_model_index],
)
return checkpoints_sorted
def _rotate_checkpoints(self, use_mtime=False) -> None: