Add automatic best model loading to Trainer (#7431)
* Add automatic best model loading to Trainer * Some small fixes * Formatting
This commit is contained in:
@@ -20,7 +20,7 @@ from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
|
|||||||
from tqdm.auto import tqdm, trange
|
from tqdm.auto import tqdm, trange
|
||||||
|
|
||||||
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
|
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
|
||||||
from .file_utils import is_datasets_available, is_torch_tpu_available
|
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_torch_tpu_available
|
||||||
from .integrations import (
|
from .integrations import (
|
||||||
default_hp_search_backend,
|
default_hp_search_backend,
|
||||||
is_comet_available,
|
is_comet_available,
|
||||||
@@ -42,6 +42,7 @@ from .trainer_utils import (
|
|||||||
EvaluationStrategy,
|
EvaluationStrategy,
|
||||||
HPSearchBackend,
|
HPSearchBackend,
|
||||||
PredictionOutput,
|
PredictionOutput,
|
||||||
|
TrainerState,
|
||||||
TrainOutput,
|
TrainOutput,
|
||||||
default_compute_objective,
|
default_compute_objective,
|
||||||
default_hp_space,
|
default_hp_space,
|
||||||
@@ -642,6 +643,7 @@ class Trainer:
|
|||||||
self.args.max_steps = t_total
|
self.args.max_steps = t_total
|
||||||
|
|
||||||
self.create_optimizer_and_scheduler(num_training_steps=t_total)
|
self.create_optimizer_and_scheduler(num_training_steps=t_total)
|
||||||
|
self.state = TrainerState()
|
||||||
|
|
||||||
# Check if saved optimizer or scheduler states exist
|
# Check if saved optimizer or scheduler states exist
|
||||||
if (
|
if (
|
||||||
@@ -657,6 +659,10 @@ class Trainer:
|
|||||||
self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
|
self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
|
||||||
reissue_pt_warnings(caught_warnings)
|
reissue_pt_warnings(caught_warnings)
|
||||||
|
|
||||||
|
# Check if a saved Trainer state exist
|
||||||
|
if model_path is not None and os.path.isfile(os.path.join(model_path, "trainer_state.json")):
|
||||||
|
self.state = TrainerState.load_from_json(os.path.join(model_path, "trainer_state.json"))
|
||||||
|
|
||||||
model = self.model
|
model = self.model
|
||||||
if self.args.fp16 and _use_apex:
|
if self.args.fp16 and _use_apex:
|
||||||
if not is_apex_available():
|
if not is_apex_available():
|
||||||
@@ -803,44 +809,15 @@ class Trainer:
|
|||||||
):
|
):
|
||||||
metrics = self.evaluate()
|
metrics = self.evaluate()
|
||||||
self._report_to_hp_search(trial, epoch, metrics)
|
self._report_to_hp_search(trial, epoch, metrics)
|
||||||
|
if self.args.load_best_model_at_end:
|
||||||
|
self._save_training(model, trial, metrics=metrics)
|
||||||
|
|
||||||
if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
|
if (
|
||||||
# In all cases (even distributed/parallel), self.model is always a reference
|
not self.args.load_best_model_at_end
|
||||||
# to the model we want to save.
|
and self.args.save_steps > 0
|
||||||
if hasattr(model, "module"):
|
and self.global_step % self.args.save_steps == 0
|
||||||
assert (
|
):
|
||||||
model.module is self.model
|
self._save_training(model, trial)
|
||||||
), f"Module {model.module} should be a reference to self.model"
|
|
||||||
else:
|
|
||||||
assert model is self.model, f"Model {model} should be a reference to self.model"
|
|
||||||
# Save model checkpoint
|
|
||||||
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}"
|
|
||||||
if self.hp_search_backend is not None and trial is not None:
|
|
||||||
run_id = (
|
|
||||||
trial.number
|
|
||||||
if self.hp_search_backend == HPSearchBackend.OPTUNA
|
|
||||||
else tune.get_trial_id()
|
|
||||||
)
|
|
||||||
checkpoint_folder += f"-run-{run_id}"
|
|
||||||
output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
|
|
||||||
|
|
||||||
self.store_flos()
|
|
||||||
self.save_model(output_dir)
|
|
||||||
|
|
||||||
if self.is_world_process_zero():
|
|
||||||
self._rotate_checkpoints(use_mtime=True)
|
|
||||||
|
|
||||||
if is_torch_tpu_available():
|
|
||||||
xm.rendezvous("saving_optimizer_states")
|
|
||||||
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
|
||||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
|
||||||
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
|
||||||
reissue_pt_warnings(caught_warnings)
|
|
||||||
elif self.is_world_process_zero():
|
|
||||||
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
|
||||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
|
||||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
|
||||||
reissue_pt_warnings(caught_warnings)
|
|
||||||
|
|
||||||
epoch_pbar.update(1)
|
epoch_pbar.update(1)
|
||||||
if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
|
if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
|
||||||
@@ -851,6 +828,8 @@ class Trainer:
|
|||||||
if self.args.evaluation_strategy == EvaluationStrategy.EPOCH:
|
if self.args.evaluation_strategy == EvaluationStrategy.EPOCH:
|
||||||
metrics = self.evaluate()
|
metrics = self.evaluate()
|
||||||
self._report_to_hp_search(trial, epoch, metrics)
|
self._report_to_hp_search(trial, epoch, metrics)
|
||||||
|
if self.args.load_best_model_at_end:
|
||||||
|
self._save_training(model, trial, metrics=metrics)
|
||||||
|
|
||||||
if self.args.tpu_metrics_debug or self.args.debug:
|
if self.args.tpu_metrics_debug or self.args.debug:
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
@@ -872,8 +851,73 @@ class Trainer:
|
|||||||
delattr(self, "_past")
|
delattr(self, "_past")
|
||||||
|
|
||||||
logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
|
logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
|
||||||
|
if self.args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
|
||||||
|
logger.info(
|
||||||
|
f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
|
||||||
|
)
|
||||||
|
if isinstance(model, PreTrainedModel):
|
||||||
|
self.model = model.from_pretrained(self.state.best_model_checkpoint)
|
||||||
|
self.model = self.model.to(self.args.device)
|
||||||
|
else:
|
||||||
|
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
|
||||||
|
self.model.load_state_dict(state_dict)
|
||||||
|
|
||||||
return TrainOutput(self.global_step, tr_loss.item() / self.global_step)
|
return TrainOutput(self.global_step, tr_loss.item() / self.global_step)
|
||||||
|
|
||||||
|
def _save_training(self, model, trial, metrics=None):
|
||||||
|
# In all cases (even distributed/parallel), self.model is always a reference
|
||||||
|
# to the model we want to save.
|
||||||
|
if hasattr(model, "module"):
|
||||||
|
assert model.module is self.model, f"Module {model.module} should be a reference to self.model"
|
||||||
|
else:
|
||||||
|
assert model is self.model, f"Model {model} should be a reference to self.model"
|
||||||
|
# Save model checkpoint
|
||||||
|
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}"
|
||||||
|
if self.hp_search_backend is not None and trial is not None:
|
||||||
|
run_id = trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id()
|
||||||
|
checkpoint_folder += f"-run-{run_id}"
|
||||||
|
output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
|
||||||
|
|
||||||
|
self.store_flos()
|
||||||
|
self.save_model(output_dir)
|
||||||
|
|
||||||
|
# Save optimizer and scheduler
|
||||||
|
if is_torch_tpu_available():
|
||||||
|
xm.rendezvous("saving_optimizer_states")
|
||||||
|
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||||
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
|
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||||
|
reissue_pt_warnings(caught_warnings)
|
||||||
|
elif self.is_world_process_zero():
|
||||||
|
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||||
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
|
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||||
|
reissue_pt_warnings(caught_warnings)
|
||||||
|
|
||||||
|
# Determine the new best metric / best model checkpoint
|
||||||
|
if metrics is not None:
|
||||||
|
metric_to_check = self.args.metric_for_best_model
|
||||||
|
if not metric_to_check.startswith("eval_"):
|
||||||
|
metric_to_check = f"eval_{metric_to_check}"
|
||||||
|
metric_value = metrics[metric_to_check]
|
||||||
|
|
||||||
|
operator = np.greater if self.args.greater_is_better else np.less
|
||||||
|
if (
|
||||||
|
self.state.best_metric is None
|
||||||
|
or self.state.best_model_checkpoint is None
|
||||||
|
or operator(metric_value, self.state.best_metric)
|
||||||
|
):
|
||||||
|
self.state.best_metric = metric_value
|
||||||
|
self.state.best_model_checkpoint = output_dir
|
||||||
|
|
||||||
|
# Save the Trainer state
|
||||||
|
if self.is_world_process_zero():
|
||||||
|
self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
|
||||||
|
|
||||||
|
# Maybe delete some older checkpoints.
|
||||||
|
if self.is_world_process_zero():
|
||||||
|
self._rotate_checkpoints(use_mtime=True)
|
||||||
|
|
||||||
def hyperparameter_search(
|
def hyperparameter_search(
|
||||||
self,
|
self,
|
||||||
hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
|
hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
|
||||||
@@ -1164,11 +1208,13 @@ class Trainer:
|
|||||||
|
|
||||||
# Save a trained model and configuration using `save_pretrained()`.
|
# Save a trained model and configuration using `save_pretrained()`.
|
||||||
# They can then be reloaded using `from_pretrained()`
|
# They can then be reloaded using `from_pretrained()`
|
||||||
if not isinstance(self.model, PreTrainedModel):
|
|
||||||
raise ValueError("Trainer.model appears to not be a PreTrainedModel")
|
|
||||||
|
|
||||||
xm.rendezvous("saving_checkpoint")
|
xm.rendezvous("saving_checkpoint")
|
||||||
self.model.save_pretrained(output_dir)
|
if not isinstance(self.model, PreTrainedModel):
|
||||||
|
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
|
||||||
|
state_dict = self.model.state_dict()
|
||||||
|
xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
||||||
|
else:
|
||||||
|
self.model.save_pretrained(output_dir)
|
||||||
if self.tokenizer is not None:
|
if self.tokenizer is not None:
|
||||||
self.tokenizer.save_pretrained(output_dir)
|
self.tokenizer.save_pretrained(output_dir)
|
||||||
|
|
||||||
@@ -1179,8 +1225,11 @@ class Trainer:
|
|||||||
# Save a trained model and configuration using `save_pretrained()`.
|
# Save a trained model and configuration using `save_pretrained()`.
|
||||||
# They can then be reloaded using `from_pretrained()`
|
# They can then be reloaded using `from_pretrained()`
|
||||||
if not isinstance(self.model, PreTrainedModel):
|
if not isinstance(self.model, PreTrainedModel):
|
||||||
raise ValueError("Trainer.model appears to not be a PreTrainedModel")
|
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
|
||||||
self.model.save_pretrained(output_dir)
|
state_dict = self.model.state_dict()
|
||||||
|
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
||||||
|
else:
|
||||||
|
self.model.save_pretrained(output_dir)
|
||||||
if self.tokenizer is not None:
|
if self.tokenizer is not None:
|
||||||
self.tokenizer.save_pretrained(output_dir)
|
self.tokenizer.save_pretrained(output_dir)
|
||||||
|
|
||||||
@@ -1215,6 +1264,13 @@ class Trainer:
|
|||||||
|
|
||||||
checkpoints_sorted = sorted(ordering_and_checkpoint_path)
|
checkpoints_sorted = sorted(ordering_and_checkpoint_path)
|
||||||
checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
|
checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
|
||||||
|
# Make sure we don't delete the best model.
|
||||||
|
if self.state.best_model_checkpoint is not None:
|
||||||
|
best_model_index = checkpoints_sorted.index(self.state.best_model_checkpoint)
|
||||||
|
checkpoints_sorted[best_model_index], checkpoints_sorted[best_model_index][-1] = (
|
||||||
|
checkpoints_sorted[-1],
|
||||||
|
checkpoints_sorted[best_model_index],
|
||||||
|
)
|
||||||
return checkpoints_sorted
|
return checkpoints_sorted
|
||||||
|
|
||||||
def _rotate_checkpoints(self, use_mtime=False) -> None:
|
def _rotate_checkpoints(self, use_mtime=False) -> None:
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
|
import dataclasses
|
||||||
|
import json
|
||||||
import random
|
import random
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
|
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -213,3 +216,26 @@ def distributed_broadcast_scalars(
|
|||||||
raise AssertionError("Not currently using distributed training")
|
raise AssertionError("Not currently using distributed training")
|
||||||
else:
|
else:
|
||||||
raise ImportError("Torch must be installed to use `distributed_broadcast_scalars`")
|
raise ImportError("Torch must be installed to use `distributed_broadcast_scalars`")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrainerState:
|
||||||
|
"""
|
||||||
|
A class containing the `Trainer` fields that will be saved along the model and optimizer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
best_metric: Optional[float] = None
|
||||||
|
best_model_checkpoint: Optional[str] = None
|
||||||
|
|
||||||
|
def save_to_json(self, json_path: str):
|
||||||
|
""" Save the content of this instance in JSON format inside :obj:`json_path`."""
|
||||||
|
json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"
|
||||||
|
with open(json_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(json_string)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load_from_json(cls, json_path: str):
|
||||||
|
""" Create an instance from the content of :obj:`json_path`."""
|
||||||
|
with open(json_path, "r", encoding="utf-8") as f:
|
||||||
|
text = f.read()
|
||||||
|
return cls(**json.loads(text))
|
||||||
|
|||||||
@@ -145,6 +145,28 @@ class TrainingArguments:
|
|||||||
Will eventually default to :obj:`["labels"]` except if the model used is one of the
|
Will eventually default to :obj:`["labels"]` except if the model used is one of the
|
||||||
:obj:`XxxForQuestionAnswering` in which case it will default to
|
:obj:`XxxForQuestionAnswering` in which case it will default to
|
||||||
:obj:`["start_positions", "end_positions"]`.
|
:obj:`["start_positions", "end_positions"]`.
|
||||||
|
load_best_model_at_end (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not to load the best model found during training at the end of training.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
When set to :obj:`True`, the parameters :obj:`save_steps` will be ignored and the model will be saved
|
||||||
|
after each evaluation.
|
||||||
|
metric_for_best_model (:obj:`str`, `optional`)
|
||||||
|
Use in conjunction with :obj:`load_best_model_at_end` to specify the metric to use to compare two different
|
||||||
|
models. Must be the name of a metric returned by the evaluation with or without the prefix :obj:`"eval_"`.
|
||||||
|
Will default to :obj:`"loss"` if unspecified and :obj:`load_best_model_at_end=True` (to use the evaluation
|
||||||
|
loss).
|
||||||
|
|
||||||
|
If you set this value, :obj:`greater_is_better` will defaut to :obj:`True`. Don't forget to set it to
|
||||||
|
:obj:`False` if your metric is better when lower.
|
||||||
|
greater_is_better (:obj:`bool`, `optional`)
|
||||||
|
Use in conjunction with :obj:`load_best_model_at_end` and :obj:`metric_for_best_model` to specify if better
|
||||||
|
models should have a greater metric or not. Will default to:
|
||||||
|
|
||||||
|
- :obj:`True` if :obj:`metric_for_best_model` is set to a value that isn't :obj:`"loss"` or
|
||||||
|
:obj:`"eval_loss"`.
|
||||||
|
- :obj:`False` if :obj:`metric_for_best_model` is not set, or set to :obj:`"loss"` or :obj:`"eval_loss"`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
output_dir: str = field(
|
output_dir: str = field(
|
||||||
@@ -287,6 +309,17 @@ class TrainingArguments:
|
|||||||
default=None, metadata={"help": "The list of keys in your dictionary of inputs that correspond to the labels."}
|
default=None, metadata={"help": "The list of keys in your dictionary of inputs that correspond to the labels."}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
load_best_model_at_end: Optional[bool] = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Whether or not to load the best model found during training at the end of training."},
|
||||||
|
)
|
||||||
|
metric_for_best_model: Optional[str] = field(
|
||||||
|
default=None, metadata={"help": "The metric to use to compare two different models."}
|
||||||
|
)
|
||||||
|
greater_is_better: Optional[bool] = field(
|
||||||
|
default=None, metadata={"help": "Whether the `metric_for_best_model` should be maximized or not."}
|
||||||
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.disable_tqdm is None:
|
if self.disable_tqdm is None:
|
||||||
self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
|
self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
|
||||||
@@ -304,6 +337,11 @@ class TrainingArguments:
|
|||||||
if self.eval_steps is None:
|
if self.eval_steps is None:
|
||||||
self.eval_steps = self.logging_steps
|
self.eval_steps = self.logging_steps
|
||||||
|
|
||||||
|
if self.load_best_model_at_end and self.metric_for_best_model is None:
|
||||||
|
self.metric_for_best_model = "loss"
|
||||||
|
if self.greater_is_better is None and self.metric_for_best_model is not None:
|
||||||
|
self.greater_is_better = self.metric_for_best_model not in ["loss", "eval_loss"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def train_batch_size(self) -> int:
|
def train_batch_size(self) -> int:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,9 +1,13 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from transformers import AutoTokenizer, TrainingArguments, is_torch_available
|
from transformers import AutoTokenizer, PretrainedConfig, TrainingArguments, is_torch_available
|
||||||
|
from transformers.file_utils import WEIGHTS_NAME
|
||||||
from transformers.testing_utils import get_tests_dir, require_torch, slow
|
from transformers.testing_utils import get_tests_dir, require_torch, slow
|
||||||
|
|
||||||
|
|
||||||
@@ -16,6 +20,7 @@ if is_torch_available():
|
|||||||
GlueDataset,
|
GlueDataset,
|
||||||
GlueDataTrainingArguments,
|
GlueDataTrainingArguments,
|
||||||
LineByLineTextDataset,
|
LineByLineTextDataset,
|
||||||
|
PreTrainedModel,
|
||||||
Trainer,
|
Trainer,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -51,6 +56,14 @@ class AlmostAccuracy:
|
|||||||
return {"accuracy": true.astype(np.float32).mean().item()}
|
return {"accuracy": true.astype(np.float32).mean().item()}
|
||||||
|
|
||||||
|
|
||||||
|
class RegressionModelConfig(PretrainedConfig):
|
||||||
|
def __init__(self, a=0, b=0, double_output=False, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.a = a
|
||||||
|
self.b = b
|
||||||
|
self.double_output = double_output
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
|
|
||||||
class SampleIterableDataset(IterableDataset):
|
class SampleIterableDataset(IterableDataset):
|
||||||
@@ -79,15 +92,34 @@ if is_torch_available():
|
|||||||
loss = torch.nn.functional.mse_loss(y, labels)
|
loss = torch.nn.functional.mse_loss(y, labels)
|
||||||
return (loss, y, y) if self.double_output else (loss, y)
|
return (loss, y, y) if self.double_output else (loss, y)
|
||||||
|
|
||||||
|
class RegressionPreTrainedModel(PreTrainedModel):
|
||||||
|
config_class = RegressionModelConfig
|
||||||
|
base_model_prefix = "regression"
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.a = torch.nn.Parameter(torch.tensor(config.a).float())
|
||||||
|
self.b = torch.nn.Parameter(torch.tensor(config.b).float())
|
||||||
|
self.double_output = config.double_output
|
||||||
|
|
||||||
|
def forward(self, input_x=None, labels=None, **kwargs):
|
||||||
|
y = input_x * self.a + self.b
|
||||||
|
if labels is None:
|
||||||
|
return (y, y) if self.double_output else (y,)
|
||||||
|
loss = torch.nn.functional.mse_loss(y, labels)
|
||||||
|
return (loss, y, y) if self.double_output else (loss, y)
|
||||||
|
|
||||||
def get_regression_trainer(a=0, b=0, double_output=False, train_len=64, eval_len=64, **kwargs):
|
def get_regression_trainer(a=0, b=0, double_output=False, train_len=64, eval_len=64, **kwargs):
|
||||||
label_names = kwargs.get("label_names", None)
|
label_names = kwargs.get("label_names", None)
|
||||||
train_dataset = RegressionDataset(length=train_len, label_names=label_names)
|
train_dataset = RegressionDataset(length=train_len, label_names=label_names)
|
||||||
eval_dataset = RegressionDataset(length=eval_len, label_names=label_names)
|
eval_dataset = RegressionDataset(length=eval_len, label_names=label_names)
|
||||||
model = RegressionModel(a, b, double_output)
|
config = RegressionModelConfig(a=a, b=b, double_output=double_output)
|
||||||
|
model = RegressionPreTrainedModel(config)
|
||||||
compute_metrics = kwargs.pop("compute_metrics", None)
|
compute_metrics = kwargs.pop("compute_metrics", None)
|
||||||
data_collator = kwargs.pop("data_collator", None)
|
data_collator = kwargs.pop("data_collator", None)
|
||||||
optimizers = kwargs.pop("optimizers", (None, None))
|
optimizers = kwargs.pop("optimizers", (None, None))
|
||||||
args = TrainingArguments("./regression", **kwargs)
|
output_dir = kwargs.pop("output_dir", "./regression")
|
||||||
|
args = TrainingArguments(output_dir, **kwargs)
|
||||||
return Trainer(
|
return Trainer(
|
||||||
model,
|
model,
|
||||||
args,
|
args,
|
||||||
@@ -119,6 +151,39 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||||||
self.assertTrue(torch.allclose(model.a, a))
|
self.assertTrue(torch.allclose(model.a, a))
|
||||||
self.assertTrue(torch.allclose(model.b, b))
|
self.assertTrue(torch.allclose(model.b, b))
|
||||||
|
|
||||||
|
def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True):
|
||||||
|
file_list = [WEIGHTS_NAME, "training_args.bin", "log_history.json", "optimizer.pt", "scheduler.pt"]
|
||||||
|
if is_pretrained:
|
||||||
|
file_list.append("config.json")
|
||||||
|
for step in range(freq, total, freq):
|
||||||
|
checkpoint = os.path.join(output_dir, f"checkpoint-{step}")
|
||||||
|
self.assertTrue(os.path.isdir(checkpoint))
|
||||||
|
for filename in file_list:
|
||||||
|
self.assertTrue(os.path.isfile(os.path.join(checkpoint, filename)))
|
||||||
|
|
||||||
|
def check_best_model_has_been_loaded(
|
||||||
|
self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True
|
||||||
|
):
|
||||||
|
checkpoint = os.path.join(output_dir, f"checkpoint-{(total // freq) * freq}")
|
||||||
|
log_history = json.load(open(os.path.join(checkpoint, "log_history.json")))
|
||||||
|
|
||||||
|
values = [d[metric] for d in log_history]
|
||||||
|
best_value = max(values) if greater_is_better else min(values)
|
||||||
|
best_checkpoint = (values.index(best_value) + 1) * freq
|
||||||
|
checkpoint = os.path.join(output_dir, f"checkpoint-{best_checkpoint}")
|
||||||
|
if is_pretrained:
|
||||||
|
best_model = RegressionPreTrainedModel.from_pretrained(checkpoint)
|
||||||
|
best_model.to(trainer.args.device)
|
||||||
|
else:
|
||||||
|
best_model = RegressionModel()
|
||||||
|
state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME))
|
||||||
|
best_model.load_state_dict(state_dict)
|
||||||
|
self.assertTrue(torch.allclose(best_model.a, trainer.model.a))
|
||||||
|
self.assertTrue(torch.allclose(best_model.b, trainer.model.b))
|
||||||
|
|
||||||
|
metrics = trainer.evaluate()
|
||||||
|
self.assertEqual(metrics[metric], best_value)
|
||||||
|
|
||||||
def test_reproducible_training(self):
|
def test_reproducible_training(self):
|
||||||
# Checks that training worked, model trained and seed made a reproducible training.
|
# Checks that training worked, model trained and seed made a reproducible training.
|
||||||
trainer = get_regression_trainer(learning_rate=0.1)
|
trainer = get_regression_trainer(learning_rate=0.1)
|
||||||
@@ -287,6 +352,87 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||||||
trainer.train()
|
trainer.train()
|
||||||
self.check_trained_model(trainer.model, alternate_seed=True)
|
self.check_trained_model(trainer.model, alternate_seed=True)
|
||||||
|
|
||||||
|
def test_save_checkpoints(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5)
|
||||||
|
trainer.train()
|
||||||
|
self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size))
|
||||||
|
|
||||||
|
# With a regular model that is not a PreTrainedModel
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5)
|
||||||
|
trainer.model = RegressionModel()
|
||||||
|
trainer.train()
|
||||||
|
self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False)
|
||||||
|
|
||||||
|
def test_load_best_model_at_end(self):
|
||||||
|
total = int(self.n_epochs * 64 / self.batch_size)
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
trainer = get_regression_trainer(
|
||||||
|
a=1.5,
|
||||||
|
b=2.5,
|
||||||
|
output_dir=tmpdir,
|
||||||
|
learning_rate=0.1,
|
||||||
|
eval_steps=5,
|
||||||
|
evaluation_strategy="steps",
|
||||||
|
load_best_model_at_end=True,
|
||||||
|
)
|
||||||
|
self.assertFalse(trainer.args.greater_is_better)
|
||||||
|
trainer.train()
|
||||||
|
self.check_saved_checkpoints(tmpdir, 5, total)
|
||||||
|
self.check_best_model_has_been_loaded(tmpdir, 5, total, trainer, "eval_loss")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
trainer = get_regression_trainer(
|
||||||
|
a=1.5,
|
||||||
|
b=2.5,
|
||||||
|
output_dir=tmpdir,
|
||||||
|
learning_rate=0.1,
|
||||||
|
eval_steps=5,
|
||||||
|
evaluation_strategy="steps",
|
||||||
|
load_best_model_at_end=True,
|
||||||
|
metric_for_best_model="accuracy",
|
||||||
|
compute_metrics=AlmostAccuracy(),
|
||||||
|
)
|
||||||
|
self.assertTrue(trainer.args.greater_is_better)
|
||||||
|
trainer.train()
|
||||||
|
self.check_saved_checkpoints(tmpdir, 5, total)
|
||||||
|
self.check_best_model_has_been_loaded(tmpdir, 5, total, trainer, "eval_accuracy", greater_is_better=True)
|
||||||
|
|
||||||
|
# Save is done every eval regardless of the strategy
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
trainer = get_regression_trainer(
|
||||||
|
a=1.5,
|
||||||
|
b=2.5,
|
||||||
|
output_dir=tmpdir,
|
||||||
|
learning_rate=0.1,
|
||||||
|
evaluation_strategy="epoch",
|
||||||
|
load_best_model_at_end=True,
|
||||||
|
metric_for_best_model="accuracy",
|
||||||
|
compute_metrics=AlmostAccuracy(),
|
||||||
|
)
|
||||||
|
self.assertTrue(trainer.args.greater_is_better)
|
||||||
|
trainer.train()
|
||||||
|
self.check_saved_checkpoints(tmpdir, 64 // self.batch_size, total)
|
||||||
|
self.check_best_model_has_been_loaded(
|
||||||
|
tmpdir, 64 // self.batch_size, total, trainer, "eval_accuracy", greater_is_better=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test this works with a non PreTrainedModel
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
trainer = get_regression_trainer(
|
||||||
|
output_dir=tmpdir,
|
||||||
|
learning_rate=0.1,
|
||||||
|
eval_steps=5,
|
||||||
|
evaluation_strategy="steps",
|
||||||
|
load_best_model_at_end=True,
|
||||||
|
)
|
||||||
|
trainer.model = RegressionModel(a=1.5, b=2.5)
|
||||||
|
self.assertFalse(trainer.args.greater_is_better)
|
||||||
|
trainer.train()
|
||||||
|
self.check_saved_checkpoints(tmpdir, 5, total, is_pretrained=False)
|
||||||
|
self.check_best_model_has_been_loaded(tmpdir, 5, total, trainer, "eval_loss", is_pretrained=False)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_trainer_eval_mrpc(self):
|
def test_trainer_eval_mrpc(self):
|
||||||
MODEL_ID = "bert-base-cased-finetuned-mrpc"
|
MODEL_ID = "bert-base-cased-finetuned-mrpc"
|
||||||
|
|||||||
Reference in New Issue
Block a user