Add automatic best model loading to Trainer (#7431)
* Add automatic best model loading to Trainer * Some small fixes * Formatting
This commit is contained in:
@@ -1,9 +1,13 @@
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
|
||||
from transformers import AutoTokenizer, TrainingArguments, is_torch_available
|
||||
from transformers import AutoTokenizer, PretrainedConfig, TrainingArguments, is_torch_available
|
||||
from transformers.file_utils import WEIGHTS_NAME
|
||||
from transformers.testing_utils import get_tests_dir, require_torch, slow
|
||||
|
||||
|
||||
@@ -16,6 +20,7 @@ if is_torch_available():
|
||||
GlueDataset,
|
||||
GlueDataTrainingArguments,
|
||||
LineByLineTextDataset,
|
||||
PreTrainedModel,
|
||||
Trainer,
|
||||
)
|
||||
|
||||
@@ -51,6 +56,14 @@ class AlmostAccuracy:
|
||||
return {"accuracy": true.astype(np.float32).mean().item()}
|
||||
|
||||
|
||||
class RegressionModelConfig(PretrainedConfig):
|
||||
def __init__(self, a=0, b=0, double_output=False, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.a = a
|
||||
self.b = b
|
||||
self.double_output = double_output
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
|
||||
class SampleIterableDataset(IterableDataset):
|
||||
@@ -79,15 +92,34 @@ if is_torch_available():
|
||||
loss = torch.nn.functional.mse_loss(y, labels)
|
||||
return (loss, y, y) if self.double_output else (loss, y)
|
||||
|
||||
class RegressionPreTrainedModel(PreTrainedModel):
|
||||
config_class = RegressionModelConfig
|
||||
base_model_prefix = "regression"
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.a = torch.nn.Parameter(torch.tensor(config.a).float())
|
||||
self.b = torch.nn.Parameter(torch.tensor(config.b).float())
|
||||
self.double_output = config.double_output
|
||||
|
||||
def forward(self, input_x=None, labels=None, **kwargs):
|
||||
y = input_x * self.a + self.b
|
||||
if labels is None:
|
||||
return (y, y) if self.double_output else (y,)
|
||||
loss = torch.nn.functional.mse_loss(y, labels)
|
||||
return (loss, y, y) if self.double_output else (loss, y)
|
||||
|
||||
def get_regression_trainer(a=0, b=0, double_output=False, train_len=64, eval_len=64, **kwargs):
|
||||
label_names = kwargs.get("label_names", None)
|
||||
train_dataset = RegressionDataset(length=train_len, label_names=label_names)
|
||||
eval_dataset = RegressionDataset(length=eval_len, label_names=label_names)
|
||||
model = RegressionModel(a, b, double_output)
|
||||
config = RegressionModelConfig(a=a, b=b, double_output=double_output)
|
||||
model = RegressionPreTrainedModel(config)
|
||||
compute_metrics = kwargs.pop("compute_metrics", None)
|
||||
data_collator = kwargs.pop("data_collator", None)
|
||||
optimizers = kwargs.pop("optimizers", (None, None))
|
||||
args = TrainingArguments("./regression", **kwargs)
|
||||
output_dir = kwargs.pop("output_dir", "./regression")
|
||||
args = TrainingArguments(output_dir, **kwargs)
|
||||
return Trainer(
|
||||
model,
|
||||
args,
|
||||
@@ -119,6 +151,39 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
self.assertTrue(torch.allclose(model.a, a))
|
||||
self.assertTrue(torch.allclose(model.b, b))
|
||||
|
||||
def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True):
|
||||
file_list = [WEIGHTS_NAME, "training_args.bin", "log_history.json", "optimizer.pt", "scheduler.pt"]
|
||||
if is_pretrained:
|
||||
file_list.append("config.json")
|
||||
for step in range(freq, total, freq):
|
||||
checkpoint = os.path.join(output_dir, f"checkpoint-{step}")
|
||||
self.assertTrue(os.path.isdir(checkpoint))
|
||||
for filename in file_list:
|
||||
self.assertTrue(os.path.isfile(os.path.join(checkpoint, filename)))
|
||||
|
||||
def check_best_model_has_been_loaded(
|
||||
self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True
|
||||
):
|
||||
checkpoint = os.path.join(output_dir, f"checkpoint-{(total // freq) * freq}")
|
||||
log_history = json.load(open(os.path.join(checkpoint, "log_history.json")))
|
||||
|
||||
values = [d[metric] for d in log_history]
|
||||
best_value = max(values) if greater_is_better else min(values)
|
||||
best_checkpoint = (values.index(best_value) + 1) * freq
|
||||
checkpoint = os.path.join(output_dir, f"checkpoint-{best_checkpoint}")
|
||||
if is_pretrained:
|
||||
best_model = RegressionPreTrainedModel.from_pretrained(checkpoint)
|
||||
best_model.to(trainer.args.device)
|
||||
else:
|
||||
best_model = RegressionModel()
|
||||
state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME))
|
||||
best_model.load_state_dict(state_dict)
|
||||
self.assertTrue(torch.allclose(best_model.a, trainer.model.a))
|
||||
self.assertTrue(torch.allclose(best_model.b, trainer.model.b))
|
||||
|
||||
metrics = trainer.evaluate()
|
||||
self.assertEqual(metrics[metric], best_value)
|
||||
|
||||
def test_reproducible_training(self):
|
||||
# Checks that training worked, model trained and seed made a reproducible training.
|
||||
trainer = get_regression_trainer(learning_rate=0.1)
|
||||
@@ -287,6 +352,87 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
trainer.train()
|
||||
self.check_trained_model(trainer.model, alternate_seed=True)
|
||||
|
||||
def test_save_checkpoints(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5)
|
||||
trainer.train()
|
||||
self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size))
|
||||
|
||||
# With a regular model that is not a PreTrainedModel
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5)
|
||||
trainer.model = RegressionModel()
|
||||
trainer.train()
|
||||
self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False)
|
||||
|
||||
def test_load_best_model_at_end(self):
|
||||
total = int(self.n_epochs * 64 / self.batch_size)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
trainer = get_regression_trainer(
|
||||
a=1.5,
|
||||
b=2.5,
|
||||
output_dir=tmpdir,
|
||||
learning_rate=0.1,
|
||||
eval_steps=5,
|
||||
evaluation_strategy="steps",
|
||||
load_best_model_at_end=True,
|
||||
)
|
||||
self.assertFalse(trainer.args.greater_is_better)
|
||||
trainer.train()
|
||||
self.check_saved_checkpoints(tmpdir, 5, total)
|
||||
self.check_best_model_has_been_loaded(tmpdir, 5, total, trainer, "eval_loss")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
trainer = get_regression_trainer(
|
||||
a=1.5,
|
||||
b=2.5,
|
||||
output_dir=tmpdir,
|
||||
learning_rate=0.1,
|
||||
eval_steps=5,
|
||||
evaluation_strategy="steps",
|
||||
load_best_model_at_end=True,
|
||||
metric_for_best_model="accuracy",
|
||||
compute_metrics=AlmostAccuracy(),
|
||||
)
|
||||
self.assertTrue(trainer.args.greater_is_better)
|
||||
trainer.train()
|
||||
self.check_saved_checkpoints(tmpdir, 5, total)
|
||||
self.check_best_model_has_been_loaded(tmpdir, 5, total, trainer, "eval_accuracy", greater_is_better=True)
|
||||
|
||||
# Save is done every eval regardless of the strategy
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
trainer = get_regression_trainer(
|
||||
a=1.5,
|
||||
b=2.5,
|
||||
output_dir=tmpdir,
|
||||
learning_rate=0.1,
|
||||
evaluation_strategy="epoch",
|
||||
load_best_model_at_end=True,
|
||||
metric_for_best_model="accuracy",
|
||||
compute_metrics=AlmostAccuracy(),
|
||||
)
|
||||
self.assertTrue(trainer.args.greater_is_better)
|
||||
trainer.train()
|
||||
self.check_saved_checkpoints(tmpdir, 64 // self.batch_size, total)
|
||||
self.check_best_model_has_been_loaded(
|
||||
tmpdir, 64 // self.batch_size, total, trainer, "eval_accuracy", greater_is_better=True
|
||||
)
|
||||
|
||||
# Test this works with a non PreTrainedModel
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
trainer = get_regression_trainer(
|
||||
output_dir=tmpdir,
|
||||
learning_rate=0.1,
|
||||
eval_steps=5,
|
||||
evaluation_strategy="steps",
|
||||
load_best_model_at_end=True,
|
||||
)
|
||||
trainer.model = RegressionModel(a=1.5, b=2.5)
|
||||
self.assertFalse(trainer.args.greater_is_better)
|
||||
trainer.train()
|
||||
self.check_saved_checkpoints(tmpdir, 5, total, is_pretrained=False)
|
||||
self.check_best_model_has_been_loaded(tmpdir, 5, total, trainer, "eval_loss", is_pretrained=False)
|
||||
|
||||
@slow
|
||||
def test_trainer_eval_mrpc(self):
|
||||
MODEL_ID = "bert-base-cased-finetuned-mrpc"
|
||||
|
||||
Reference in New Issue
Block a user