Add automatic best model loading to Trainer (#7431)
* Add automatic best model loading to Trainer * Some small fixes * Formatting
This commit is contained in:
@@ -145,6 +145,28 @@ class TrainingArguments:
|
||||
Will eventually default to :obj:`["labels"]` except if the model used is one of the
|
||||
:obj:`XxxForQuestionAnswering` in which case it will default to
|
||||
:obj:`["start_positions", "end_positions"]`.
|
||||
load_best_model_at_end (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to load the best model found during training at the end of training.
|
||||
|
||||
.. note::
|
||||
|
||||
When set to :obj:`True`, the parameters :obj:`save_steps` will be ignored and the model will be saved
|
||||
after each evaluation.
|
||||
metric_for_best_model (:obj:`str`, `optional`)
|
||||
Use in conjunction with :obj:`load_best_model_at_end` to specify the metric to use to compare two different
|
||||
models. Must be the name of a metric returned by the evaluation with or without the prefix :obj:`"eval_"`.
|
||||
Will default to :obj:`"loss"` if unspecified and :obj:`load_best_model_at_end=True` (to use the evaluation
|
||||
loss).
|
||||
|
||||
If you set this value, :obj:`greater_is_better` will defaut to :obj:`True`. Don't forget to set it to
|
||||
:obj:`False` if your metric is better when lower.
|
||||
greater_is_better (:obj:`bool`, `optional`)
|
||||
Use in conjunction with :obj:`load_best_model_at_end` and :obj:`metric_for_best_model` to specify if better
|
||||
models should have a greater metric or not. Will default to:
|
||||
|
||||
- :obj:`True` if :obj:`metric_for_best_model` is set to a value that isn't :obj:`"loss"` or
|
||||
:obj:`"eval_loss"`.
|
||||
- :obj:`False` if :obj:`metric_for_best_model` is not set, or set to :obj:`"loss"` or :obj:`"eval_loss"`.
|
||||
"""
|
||||
|
||||
output_dir: str = field(
|
||||
@@ -287,6 +309,17 @@ class TrainingArguments:
|
||||
default=None, metadata={"help": "The list of keys in your dictionary of inputs that correspond to the labels."}
|
||||
)
|
||||
|
||||
load_best_model_at_end: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to load the best model found during training at the end of training."},
|
||||
)
|
||||
metric_for_best_model: Optional[str] = field(
|
||||
default=None, metadata={"help": "The metric to use to compare two different models."}
|
||||
)
|
||||
greater_is_better: Optional[bool] = field(
|
||||
default=None, metadata={"help": "Whether the `metric_for_best_model` should be maximized or not."}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.disable_tqdm is None:
|
||||
self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
|
||||
@@ -304,6 +337,11 @@ class TrainingArguments:
|
||||
if self.eval_steps is None:
|
||||
self.eval_steps = self.logging_steps
|
||||
|
||||
if self.load_best_model_at_end and self.metric_for_best_model is None:
|
||||
self.metric_for_best_model = "loss"
|
||||
if self.greater_is_better is None and self.metric_for_best_model is not None:
|
||||
self.greater_is_better = self.metric_for_best_model not in ["loss", "eval_loss"]
|
||||
|
||||
@property
|
||||
def train_batch_size(self) -> int:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user