From 52e8392b7ebd4ebc7b796e8f14b9dae271139f5f Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Tue, 29 Sep 2020 10:41:18 -0400 Subject: [PATCH] Add automatic best model loading to Trainer (#7431) * Add automatic best model loading to Trainer * Some small fixes * Formatting --- src/transformers/trainer.py | 144 +++++++++++++++++++--------- src/transformers/trainer_utils.py | 26 +++++ src/transformers/training_args.py | 38 ++++++++ tests/test_trainer.py | 152 +++++++++++++++++++++++++++++- 4 files changed, 313 insertions(+), 47 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 7d9a093d6c..ea27b1604c 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -20,7 +20,7 @@ from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler from tqdm.auto import tqdm, trange from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator -from .file_utils import is_datasets_available, is_torch_tpu_available +from .file_utils import WEIGHTS_NAME, is_datasets_available, is_torch_tpu_available from .integrations import ( default_hp_search_backend, is_comet_available, @@ -42,6 +42,7 @@ from .trainer_utils import ( EvaluationStrategy, HPSearchBackend, PredictionOutput, + TrainerState, TrainOutput, default_compute_objective, default_hp_space, @@ -642,6 +643,7 @@ class Trainer: self.args.max_steps = t_total self.create_optimizer_and_scheduler(num_training_steps=t_total) + self.state = TrainerState() # Check if saved optimizer or scheduler states exist if ( @@ -657,6 +659,10 @@ class Trainer: self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt"))) reissue_pt_warnings(caught_warnings) + # Check if a saved Trainer state exist + if model_path is not None and os.path.isfile(os.path.join(model_path, "trainer_state.json")): + self.state = TrainerState.load_from_json(os.path.join(model_path, "trainer_state.json")) + model = self.model if self.args.fp16 and _use_apex: if not is_apex_available(): @@ -803,44 +809,15 @@ class Trainer: ): metrics = self.evaluate() self._report_to_hp_search(trial, epoch, metrics) + if self.args.load_best_model_at_end: + self._save_training(model, trial, metrics=metrics) - if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0: - # In all cases (even distributed/parallel), self.model is always a reference - # to the model we want to save. - if hasattr(model, "module"): - assert ( - model.module is self.model - ), f"Module {model.module} should be a reference to self.model" - else: - assert model is self.model, f"Model {model} should be a reference to self.model" - # Save model checkpoint - checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}" - if self.hp_search_backend is not None and trial is not None: - run_id = ( - trial.number - if self.hp_search_backend == HPSearchBackend.OPTUNA - else tune.get_trial_id() - ) - checkpoint_folder += f"-run-{run_id}" - output_dir = os.path.join(self.args.output_dir, checkpoint_folder) - - self.store_flos() - self.save_model(output_dir) - - if self.is_world_process_zero(): - self._rotate_checkpoints(use_mtime=True) - - if is_torch_tpu_available(): - xm.rendezvous("saving_optimizer_states") - xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) - with warnings.catch_warnings(record=True) as caught_warnings: - xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) - reissue_pt_warnings(caught_warnings) - elif self.is_world_process_zero(): - torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) - with warnings.catch_warnings(record=True) as caught_warnings: - torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) - reissue_pt_warnings(caught_warnings) + if ( + not self.args.load_best_model_at_end + and self.args.save_steps > 0 + and self.global_step % self.args.save_steps == 0 + ): + self._save_training(model, trial) epoch_pbar.update(1) if self.args.max_steps > 0 and self.global_step >= self.args.max_steps: @@ -851,6 +828,8 @@ class Trainer: if self.args.evaluation_strategy == EvaluationStrategy.EPOCH: metrics = self.evaluate() self._report_to_hp_search(trial, epoch, metrics) + if self.args.load_best_model_at_end: + self._save_training(model, trial, metrics=metrics) if self.args.tpu_metrics_debug or self.args.debug: if is_torch_tpu_available(): @@ -872,8 +851,73 @@ class Trainer: delattr(self, "_past") logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") + if self.args.load_best_model_at_end and self.state.best_model_checkpoint is not None: + logger.info( + f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})." + ) + if isinstance(model, PreTrainedModel): + self.model = model.from_pretrained(self.state.best_model_checkpoint) + self.model = self.model.to(self.args.device) + else: + state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)) + self.model.load_state_dict(state_dict) + return TrainOutput(self.global_step, tr_loss.item() / self.global_step) + def _save_training(self, model, trial, metrics=None): + # In all cases (even distributed/parallel), self.model is always a reference + # to the model we want to save. + if hasattr(model, "module"): + assert model.module is self.model, f"Module {model.module} should be a reference to self.model" + else: + assert model is self.model, f"Model {model} should be a reference to self.model" + # Save model checkpoint + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}" + if self.hp_search_backend is not None and trial is not None: + run_id = trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id() + checkpoint_folder += f"-run-{run_id}" + output_dir = os.path.join(self.args.output_dir, checkpoint_folder) + + self.store_flos() + self.save_model(output_dir) + + # Save optimizer and scheduler + if is_torch_tpu_available(): + xm.rendezvous("saving_optimizer_states") + xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) + with warnings.catch_warnings(record=True) as caught_warnings: + xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) + reissue_pt_warnings(caught_warnings) + elif self.is_world_process_zero(): + torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) + with warnings.catch_warnings(record=True) as caught_warnings: + torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) + reissue_pt_warnings(caught_warnings) + + # Determine the new best metric / best model checkpoint + if metrics is not None: + metric_to_check = self.args.metric_for_best_model + if not metric_to_check.startswith("eval_"): + metric_to_check = f"eval_{metric_to_check}" + metric_value = metrics[metric_to_check] + + operator = np.greater if self.args.greater_is_better else np.less + if ( + self.state.best_metric is None + or self.state.best_model_checkpoint is None + or operator(metric_value, self.state.best_metric) + ): + self.state.best_metric = metric_value + self.state.best_model_checkpoint = output_dir + + # Save the Trainer state + if self.is_world_process_zero(): + self.state.save_to_json(os.path.join(output_dir, "trainer_state.json")) + + # Maybe delete some older checkpoints. + if self.is_world_process_zero(): + self._rotate_checkpoints(use_mtime=True) + def hyperparameter_search( self, hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None, @@ -1164,11 +1208,13 @@ class Trainer: # Save a trained model and configuration using `save_pretrained()`. # They can then be reloaded using `from_pretrained()` - if not isinstance(self.model, PreTrainedModel): - raise ValueError("Trainer.model appears to not be a PreTrainedModel") - xm.rendezvous("saving_checkpoint") - self.model.save_pretrained(output_dir) + if not isinstance(self.model, PreTrainedModel): + logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") + state_dict = self.model.state_dict() + xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) + else: + self.model.save_pretrained(output_dir) if self.tokenizer is not None: self.tokenizer.save_pretrained(output_dir) @@ -1179,8 +1225,11 @@ class Trainer: # Save a trained model and configuration using `save_pretrained()`. # They can then be reloaded using `from_pretrained()` if not isinstance(self.model, PreTrainedModel): - raise ValueError("Trainer.model appears to not be a PreTrainedModel") - self.model.save_pretrained(output_dir) + logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") + state_dict = self.model.state_dict() + torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) + else: + self.model.save_pretrained(output_dir) if self.tokenizer is not None: self.tokenizer.save_pretrained(output_dir) @@ -1215,6 +1264,13 @@ class Trainer: checkpoints_sorted = sorted(ordering_and_checkpoint_path) checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted] + # Make sure we don't delete the best model. + if self.state.best_model_checkpoint is not None: + best_model_index = checkpoints_sorted.index(self.state.best_model_checkpoint) + checkpoints_sorted[best_model_index], checkpoints_sorted[best_model_index][-1] = ( + checkpoints_sorted[-1], + checkpoints_sorted[best_model_index], + ) return checkpoints_sorted def _rotate_checkpoints(self, use_mtime=False) -> None: diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index 76e69ad559..d93adda186 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -1,4 +1,7 @@ +import dataclasses +import json import random +from dataclasses import dataclass from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union import numpy as np @@ -213,3 +216,26 @@ def distributed_broadcast_scalars( raise AssertionError("Not currently using distributed training") else: raise ImportError("Torch must be installed to use `distributed_broadcast_scalars`") + + +@dataclass +class TrainerState: + """ + A class containing the `Trainer` fields that will be saved along the model and optimizer. + """ + + best_metric: Optional[float] = None + best_model_checkpoint: Optional[str] = None + + def save_to_json(self, json_path: str): + """ Save the content of this instance in JSON format inside :obj:`json_path`.""" + json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n" + with open(json_path, "w", encoding="utf-8") as f: + f.write(json_string) + + @classmethod + def load_from_json(cls, json_path: str): + """ Create an instance from the content of :obj:`json_path`.""" + with open(json_path, "r", encoding="utf-8") as f: + text = f.read() + return cls(**json.loads(text)) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 737816383c..a1f0335646 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -145,6 +145,28 @@ class TrainingArguments: Will eventually default to :obj:`["labels"]` except if the model used is one of the :obj:`XxxForQuestionAnswering` in which case it will default to :obj:`["start_positions", "end_positions"]`. + load_best_model_at_end (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to load the best model found during training at the end of training. + + .. note:: + + When set to :obj:`True`, the parameters :obj:`save_steps` will be ignored and the model will be saved + after each evaluation. + metric_for_best_model (:obj:`str`, `optional`) + Use in conjunction with :obj:`load_best_model_at_end` to specify the metric to use to compare two different + models. Must be the name of a metric returned by the evaluation with or without the prefix :obj:`"eval_"`. + Will default to :obj:`"loss"` if unspecified and :obj:`load_best_model_at_end=True` (to use the evaluation + loss). + + If you set this value, :obj:`greater_is_better` will defaut to :obj:`True`. Don't forget to set it to + :obj:`False` if your metric is better when lower. + greater_is_better (:obj:`bool`, `optional`) + Use in conjunction with :obj:`load_best_model_at_end` and :obj:`metric_for_best_model` to specify if better + models should have a greater metric or not. Will default to: + + - :obj:`True` if :obj:`metric_for_best_model` is set to a value that isn't :obj:`"loss"` or + :obj:`"eval_loss"`. + - :obj:`False` if :obj:`metric_for_best_model` is not set, or set to :obj:`"loss"` or :obj:`"eval_loss"`. """ output_dir: str = field( @@ -287,6 +309,17 @@ class TrainingArguments: default=None, metadata={"help": "The list of keys in your dictionary of inputs that correspond to the labels."} ) + load_best_model_at_end: Optional[bool] = field( + default=False, + metadata={"help": "Whether or not to load the best model found during training at the end of training."}, + ) + metric_for_best_model: Optional[str] = field( + default=None, metadata={"help": "The metric to use to compare two different models."} + ) + greater_is_better: Optional[bool] = field( + default=None, metadata={"help": "Whether the `metric_for_best_model` should be maximized or not."} + ) + def __post_init__(self): if self.disable_tqdm is None: self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN @@ -304,6 +337,11 @@ class TrainingArguments: if self.eval_steps is None: self.eval_steps = self.logging_steps + if self.load_best_model_at_end and self.metric_for_best_model is None: + self.metric_for_best_model = "loss" + if self.greater_is_better is None and self.metric_for_best_model is not None: + self.greater_is_better = self.metric_for_best_model not in ["loss", "eval_loss"] + @property def train_batch_size(self) -> int: """ diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 4613b284bd..618f655890 100755 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -1,9 +1,13 @@ +import json +import os +import tempfile import unittest import datasets import numpy as np -from transformers import AutoTokenizer, TrainingArguments, is_torch_available +from transformers import AutoTokenizer, PretrainedConfig, TrainingArguments, is_torch_available +from transformers.file_utils import WEIGHTS_NAME from transformers.testing_utils import get_tests_dir, require_torch, slow @@ -16,6 +20,7 @@ if is_torch_available(): GlueDataset, GlueDataTrainingArguments, LineByLineTextDataset, + PreTrainedModel, Trainer, ) @@ -51,6 +56,14 @@ class AlmostAccuracy: return {"accuracy": true.astype(np.float32).mean().item()} +class RegressionModelConfig(PretrainedConfig): + def __init__(self, a=0, b=0, double_output=False, **kwargs): + super().__init__(**kwargs) + self.a = a + self.b = b + self.double_output = double_output + + if is_torch_available(): class SampleIterableDataset(IterableDataset): @@ -79,15 +92,34 @@ if is_torch_available(): loss = torch.nn.functional.mse_loss(y, labels) return (loss, y, y) if self.double_output else (loss, y) + class RegressionPreTrainedModel(PreTrainedModel): + config_class = RegressionModelConfig + base_model_prefix = "regression" + + def __init__(self, config): + super().__init__(config) + self.a = torch.nn.Parameter(torch.tensor(config.a).float()) + self.b = torch.nn.Parameter(torch.tensor(config.b).float()) + self.double_output = config.double_output + + def forward(self, input_x=None, labels=None, **kwargs): + y = input_x * self.a + self.b + if labels is None: + return (y, y) if self.double_output else (y,) + loss = torch.nn.functional.mse_loss(y, labels) + return (loss, y, y) if self.double_output else (loss, y) + def get_regression_trainer(a=0, b=0, double_output=False, train_len=64, eval_len=64, **kwargs): label_names = kwargs.get("label_names", None) train_dataset = RegressionDataset(length=train_len, label_names=label_names) eval_dataset = RegressionDataset(length=eval_len, label_names=label_names) - model = RegressionModel(a, b, double_output) + config = RegressionModelConfig(a=a, b=b, double_output=double_output) + model = RegressionPreTrainedModel(config) compute_metrics = kwargs.pop("compute_metrics", None) data_collator = kwargs.pop("data_collator", None) optimizers = kwargs.pop("optimizers", (None, None)) - args = TrainingArguments("./regression", **kwargs) + output_dir = kwargs.pop("output_dir", "./regression") + args = TrainingArguments(output_dir, **kwargs) return Trainer( model, args, @@ -119,6 +151,39 @@ class TrainerIntegrationTest(unittest.TestCase): self.assertTrue(torch.allclose(model.a, a)) self.assertTrue(torch.allclose(model.b, b)) + def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True): + file_list = [WEIGHTS_NAME, "training_args.bin", "log_history.json", "optimizer.pt", "scheduler.pt"] + if is_pretrained: + file_list.append("config.json") + for step in range(freq, total, freq): + checkpoint = os.path.join(output_dir, f"checkpoint-{step}") + self.assertTrue(os.path.isdir(checkpoint)) + for filename in file_list: + self.assertTrue(os.path.isfile(os.path.join(checkpoint, filename))) + + def check_best_model_has_been_loaded( + self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True + ): + checkpoint = os.path.join(output_dir, f"checkpoint-{(total // freq) * freq}") + log_history = json.load(open(os.path.join(checkpoint, "log_history.json"))) + + values = [d[metric] for d in log_history] + best_value = max(values) if greater_is_better else min(values) + best_checkpoint = (values.index(best_value) + 1) * freq + checkpoint = os.path.join(output_dir, f"checkpoint-{best_checkpoint}") + if is_pretrained: + best_model = RegressionPreTrainedModel.from_pretrained(checkpoint) + best_model.to(trainer.args.device) + else: + best_model = RegressionModel() + state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME)) + best_model.load_state_dict(state_dict) + self.assertTrue(torch.allclose(best_model.a, trainer.model.a)) + self.assertTrue(torch.allclose(best_model.b, trainer.model.b)) + + metrics = trainer.evaluate() + self.assertEqual(metrics[metric], best_value) + def test_reproducible_training(self): # Checks that training worked, model trained and seed made a reproducible training. trainer = get_regression_trainer(learning_rate=0.1) @@ -287,6 +352,87 @@ class TrainerIntegrationTest(unittest.TestCase): trainer.train() self.check_trained_model(trainer.model, alternate_seed=True) + def test_save_checkpoints(self): + with tempfile.TemporaryDirectory() as tmpdir: + trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5) + trainer.train() + self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size)) + + # With a regular model that is not a PreTrainedModel + with tempfile.TemporaryDirectory() as tmpdir: + trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5) + trainer.model = RegressionModel() + trainer.train() + self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False) + + def test_load_best_model_at_end(self): + total = int(self.n_epochs * 64 / self.batch_size) + with tempfile.TemporaryDirectory() as tmpdir: + trainer = get_regression_trainer( + a=1.5, + b=2.5, + output_dir=tmpdir, + learning_rate=0.1, + eval_steps=5, + evaluation_strategy="steps", + load_best_model_at_end=True, + ) + self.assertFalse(trainer.args.greater_is_better) + trainer.train() + self.check_saved_checkpoints(tmpdir, 5, total) + self.check_best_model_has_been_loaded(tmpdir, 5, total, trainer, "eval_loss") + + with tempfile.TemporaryDirectory() as tmpdir: + trainer = get_regression_trainer( + a=1.5, + b=2.5, + output_dir=tmpdir, + learning_rate=0.1, + eval_steps=5, + evaluation_strategy="steps", + load_best_model_at_end=True, + metric_for_best_model="accuracy", + compute_metrics=AlmostAccuracy(), + ) + self.assertTrue(trainer.args.greater_is_better) + trainer.train() + self.check_saved_checkpoints(tmpdir, 5, total) + self.check_best_model_has_been_loaded(tmpdir, 5, total, trainer, "eval_accuracy", greater_is_better=True) + + # Save is done every eval regardless of the strategy + with tempfile.TemporaryDirectory() as tmpdir: + trainer = get_regression_trainer( + a=1.5, + b=2.5, + output_dir=tmpdir, + learning_rate=0.1, + evaluation_strategy="epoch", + load_best_model_at_end=True, + metric_for_best_model="accuracy", + compute_metrics=AlmostAccuracy(), + ) + self.assertTrue(trainer.args.greater_is_better) + trainer.train() + self.check_saved_checkpoints(tmpdir, 64 // self.batch_size, total) + self.check_best_model_has_been_loaded( + tmpdir, 64 // self.batch_size, total, trainer, "eval_accuracy", greater_is_better=True + ) + + # Test this works with a non PreTrainedModel + with tempfile.TemporaryDirectory() as tmpdir: + trainer = get_regression_trainer( + output_dir=tmpdir, + learning_rate=0.1, + eval_steps=5, + evaluation_strategy="steps", + load_best_model_at_end=True, + ) + trainer.model = RegressionModel(a=1.5, b=2.5) + self.assertFalse(trainer.args.greater_is_better) + trainer.train() + self.check_saved_checkpoints(tmpdir, 5, total, is_pretrained=False) + self.check_best_model_has_been_loaded(tmpdir, 5, total, trainer, "eval_loss", is_pretrained=False) + @slow def test_trainer_eval_mrpc(self): MODEL_ID = "bert-base-cased-finetuned-mrpc"