Fix/best model checkpoint fix (#35885)
* Set best_model_checkpoint only when the checkpoint exists. Previously it was set unconditionally, without checking whether the checkpoint directory even exists; the setting logic now lives inside _save_checkpoint and best_model_checkpoint is only set if the directory exists. * Added best_global_step to TrainerState. * Added tests for best_model_checkpoint. * Fixed hard-coded values in test to prevent failures. * Added helper func and removed hard-coded best_step. * Added side effect patch generator for _eval. * Added evaluate side effect func. * Removed erroneous patching. * Fixed minor bug. * Applied Ruff. * Fixed Ruff problem in make style. * Used Trainer.set_initial_training_values.
This commit is contained in:
committed by
GitHub
parent
3bd1a0ddf1
commit
691d1b52c3
@@ -38,7 +38,7 @@ from dataclasses import MISSING, fields
|
||||
from functools import wraps
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union
|
||||
from typing import Callable, Dict, Generator, Iterable, Iterator, List, Optional, Union
|
||||
from unittest import mock
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -48,6 +48,7 @@ import urllib3
|
||||
from huggingface_hub import delete_repo
|
||||
from packaging import version
|
||||
|
||||
from transformers import Trainer
|
||||
from transformers import logging as transformers_logging
|
||||
|
||||
from .integrations import (
|
||||
@@ -1440,6 +1441,34 @@ def get_tests_dir(append_path=None):
|
||||
return tests_dir
|
||||
|
||||
|
||||
def get_steps_per_epoch(trainer: Trainer) -> int:
    """
    Return the per-epoch step count reported by `trainer.set_initial_training_values`.

    Args:
        trainer (`Trainer`):
            The trainer whose arguments and training dataloader are inspected.

    Returns:
        `int`: The element at index 1 of the values returned by
        `trainer.set_initial_training_values`.
    """
    args = trainer.args
    dataloader = trainer.get_train_dataloader()

    # Only the second entry of the returned values is needed here.
    return trainer.set_initial_training_values(
        args=args,
        dataloader=dataloader,
        total_train_batch_size=args.per_device_train_batch_size,
    )[1]
|
||||
|
||||
|
||||
def evaluate_side_effect_factory(
    side_effect_values: List[Dict[str, float]],
) -> Generator[Dict[str, float], None, None]:
    """
    Build a generator of side effects for the `_evaluate` method.

    Used when we're unsure of exactly how many times `_evaluate` will be called: every value in
    `side_effect_values` is yielded once, in order, and afterwards the last value is yielded forever.

    Args:
        side_effect_values (`List[Dict[str, float]]`):
            The metrics dicts to yield. Must contain at least one element.

    Returns:
        `Generator[Dict[str, float], None, None]`: A never-ending generator of metrics dicts.

    Raises:
        `ValueError`: If `side_effect_values` is empty.
    """
    # Fail fast with a clear message instead of an `IndexError` surfacing lazily, in the middle of
    # training, the first time the patched method is called.
    if not side_effect_values:
        raise ValueError("`side_effect_values` must contain at least one metrics dict.")

    def _side_effects() -> Generator[Dict[str, float], None, None]:
        yield from side_effect_values

        # Keep returning the final value for any additional calls.
        while True:
            yield side_effect_values[-1]

    return _side_effects()
|
||||
|
||||
|
||||
#
|
||||
# Helper functions for dealing with testing text outputs
|
||||
# The original code came from:
|
||||
|
||||
@@ -3178,12 +3178,10 @@ class Trainer:
|
||||
self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf")
|
||||
|
||||
if operator(metric_value, self.state.best_metric):
|
||||
run_dir = self._get_output_dir(trial=trial)
|
||||
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
||||
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||
|
||||
self.state.best_metric = metric_value
|
||||
self.state.best_model_checkpoint = output_dir
|
||||
|
||||
if self.args.save_strategy in [SaveStrategy.STEPS, SaveStrategy.EPOCH]:
|
||||
self.state.best_global_step = self.state.global_step
|
||||
|
||||
is_new_best_metric = True
|
||||
|
||||
@@ -3204,6 +3202,13 @@ class Trainer:
|
||||
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||
self.save_model(output_dir, _internal_call=True)
|
||||
|
||||
if self.args.save_strategy in [SaveStrategy.STEPS, SaveStrategy.EPOCH] and self.state.best_global_step:
|
||||
best_checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.best_global_step}"
|
||||
best_checkpoint_dir = os.path.join(run_dir, best_checkpoint_folder)
|
||||
|
||||
if os.path.exists(best_checkpoint_dir):
|
||||
self.state.best_model_checkpoint = best_checkpoint_dir
|
||||
|
||||
if not self.args.save_only_model:
|
||||
# Save optimizer and scheduler
|
||||
self._save_optimizer_and_scheduler(output_dir)
|
||||
|
||||
@@ -74,6 +74,9 @@ class TrainerState:
|
||||
The list of logs done since the beginning of training.
|
||||
best_metric (`float`, *optional*):
|
||||
When tracking the best model, the value of the best metric encountered so far.
|
||||
best_global_step (`int`, *optional*):
|
||||
When tracking the best model, the step at which the best metric was encountered.
|
||||
Used for setting `best_model_checkpoint`.
|
||||
best_model_checkpoint (`str`, *optional*):
|
||||
When tracking the best model, the value of the name of the checkpoint for the best model encountered so
|
||||
far.
|
||||
@@ -103,6 +106,7 @@ class TrainerState:
|
||||
total_flos: float = 0
|
||||
log_history: List[Dict[str, float]] = None
|
||||
best_metric: Optional[float] = None
|
||||
best_global_step: Optional[int] = None
|
||||
best_model_checkpoint: Optional[str] = None
|
||||
is_local_process_zero: bool = True
|
||||
is_world_process_zero: bool = True
|
||||
|
||||
@@ -62,8 +62,10 @@ from transformers.testing_utils import (
|
||||
TemporaryHubRepo,
|
||||
TestCasePlus,
|
||||
backend_device_count,
|
||||
evaluate_side_effect_factory,
|
||||
execute_subprocess_async,
|
||||
get_gpu_count,
|
||||
get_steps_per_epoch,
|
||||
get_tests_dir,
|
||||
is_staging_test,
|
||||
require_accelerate,
|
||||
@@ -4710,6 +4712,191 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
)
|
||||
self.assertTrue(trainer.args.metric_for_best_model == "loss")
|
||||
|
||||
def test_best_model_checkpoint_behavior(self):
    """
    Check how `best_metric`, `best_global_step` and `best_model_checkpoint` are tracked on the
    trainer state for several combinations of evaluation and save settings.

    In the cases that evaluate, `_evaluate` is patched so the reported `eval_accuracy` values are
    fixed and the best metric is always the first one reported.
    """
    # Case 1. Never evaluated, `save_total_limit` left unset and save_steps == 1.
    # Both best_metric and best_model_checkpoint should be None.
    # One checkpoint per training step should be present.
    with tempfile.TemporaryDirectory() as tmpdir:
        trainer = get_regression_trainer(
            output_dir=tmpdir,
            eval_strategy="steps",
            save_strategy="steps",
            save_steps=1,
            metric_for_best_model="accuracy",
            greater_is_better=True,
        )
        trainer.train()

        self.assertIsNone(trainer.state.best_metric)
        self.assertIsNone(trainer.state.best_model_checkpoint)
        self.assertEqual(len(os.listdir(tmpdir)), trainer.state.global_step)

    # Case 2. Never evaluated and save_total_limit == 1.
    # Both best_metric and best_model_checkpoint should be None.
    # Only the last checkpoint should remain.
    with tempfile.TemporaryDirectory() as tmpdir:
        trainer = get_regression_trainer(
            output_dir=tmpdir,
            eval_strategy="steps",
            save_strategy="steps",
            save_steps=1,
            metric_for_best_model="accuracy",
            greater_is_better=True,
            save_total_limit=1,
        )
        trainer.train()

        num_steps = trainer.state.global_step

        self.assertIsNone(trainer.state.best_metric)
        self.assertIsNone(trainer.state.best_model_checkpoint)
        self.assertEqual(len(os.listdir(tmpdir)), 1)

        ckpt = os.path.join(tmpdir, f"{PREFIX_CHECKPOINT_DIR}-{num_steps}")
        self.assertTrue(os.path.isdir(ckpt))
        self.assertEqual(os.listdir(tmpdir)[0], f"{PREFIX_CHECKPOINT_DIR}-{num_steps}")

    # Case 3. eval_strategy == save_strategy.
    # best_model_checkpoint should be at epoch 1.
    with tempfile.TemporaryDirectory() as tmpdir:
        trainer = get_regression_trainer(
            output_dir=tmpdir,
            eval_strategy="epoch",
            save_strategy="epoch",
            metric_for_best_model="accuracy",
            compute_metrics=AlmostAccuracy(),
            greater_is_better=True,
            load_best_model_at_end=False,
        )

        with patch.object(
            trainer,
            "_evaluate",
            side_effect=evaluate_side_effect_factory(
                [
                    {"eval_accuracy": 0.59},
                    {"eval_accuracy": 0.57},
                    {"eval_accuracy": 0.55},
                ]
            ),
        ):
            trainer.train()

        steps_per_epoch = get_steps_per_epoch(trainer)

        self.assertEqual(trainer.state.best_metric, 0.59)
        self.assertEqual(trainer.state.best_global_step, steps_per_epoch)

        best_ckpt = os.path.join(tmpdir, f"{PREFIX_CHECKPOINT_DIR}-{trainer.state.best_global_step}")
        self.assertEqual(trainer.state.best_model_checkpoint, best_ckpt)

        self.assertEqual(len(os.listdir(tmpdir)), trainer.state.num_train_epochs)

    # Case 4. eval_strategy != save_strategy.
    # The best metric is reported by the first (epoch-level) evaluation; a checkpoint is saved at every step.
    with tempfile.TemporaryDirectory() as tmpdir:
        trainer = get_regression_trainer(
            output_dir=tmpdir,
            eval_strategy="epoch",
            save_strategy="steps",
            save_steps=1,
            metric_for_best_model="accuracy",
            compute_metrics=AlmostAccuracy(),
            greater_is_better=True,
            load_best_model_at_end=False,
        )

        with patch.object(
            trainer,
            "_evaluate",
            side_effect=evaluate_side_effect_factory(
                [
                    {"eval_accuracy": 0.59},
                    {"eval_accuracy": 0.57},
                    {"eval_accuracy": 0.55},
                ]
            ),
        ):
            trainer.train()

        steps_per_epoch = get_steps_per_epoch(trainer)

        self.assertEqual(trainer.state.best_metric, 0.59)
        self.assertEqual(trainer.state.best_global_step, steps_per_epoch)

        best_ckpt = os.path.join(tmpdir, f"{PREFIX_CHECKPOINT_DIR}-{trainer.state.best_global_step}")
        self.assertEqual(trainer.state.best_model_checkpoint, best_ckpt)

        self.assertEqual(len(os.listdir(tmpdir)), trainer.state.global_step)

    # Case 5. Multiple checkpoints, save_total_limit == 1.
    # Best metric is found at step 1 and that checkpoint should be saved.
    with tempfile.TemporaryDirectory() as tmpdir:
        trainer = get_regression_trainer(
            output_dir=tmpdir,
            eval_strategy="steps",
            eval_steps=1,
            save_strategy="steps",
            save_steps=1,
            metric_for_best_model="accuracy",
            compute_metrics=AlmostAccuracy(),
            greater_is_better=True,
            save_total_limit=1,
        )

        with patch.object(
            trainer,
            "_evaluate",
            side_effect=evaluate_side_effect_factory(
                [
                    {"eval_accuracy": 0.90},
                    {"eval_accuracy": 0.80},
                    {"eval_accuracy": 0.70},
                ]
            ),
        ):
            trainer.train()

        self.assertEqual(trainer.state.best_metric, 0.90)
        self.assertEqual(trainer.state.best_global_step, 1)

        best_ckpt = os.path.join(tmpdir, f"{PREFIX_CHECKPOINT_DIR}-{trainer.state.best_global_step}")
        self.assertEqual(trainer.state.best_model_checkpoint, best_ckpt)

        self.assertEqual(len(os.listdir(tmpdir)), 1)

    # Case 6. Saving happens more often and eval/save mismatch.
    # `best_model_checkpoint` should be None due to a step mismatch.
    with tempfile.TemporaryDirectory() as tmpdir:
        trainer = get_regression_trainer(
            output_dir=tmpdir,
            eval_strategy="steps",
            eval_steps=3,
            save_strategy="steps",
            save_steps=2,
            metric_for_best_model="accuracy",
            compute_metrics=AlmostAccuracy(),
            greater_is_better=True,
        )

        with patch.object(
            trainer,
            "_evaluate",
            side_effect=evaluate_side_effect_factory(
                [
                    {"eval_accuracy": 0.90},
                    {"eval_accuracy": 0.80},
                    {"eval_accuracy": 0.70},
                ]
            ),
        ):
            trainer.train()

        self.assertEqual(trainer.state.best_metric, 0.90)
        self.assertEqual(trainer.state.best_global_step, 3)

        self.assertIsNone(trainer.state.best_model_checkpoint)

        # A checkpoint is written every 2 steps, so half as many checkpoints as steps are expected.
        self.assertEqual(len(os.listdir(tmpdir)), trainer.state.global_step // 2)
|
||||
|
||||
|
||||
@require_torch
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user