Fix/best model checkpoint fix (#35885)

* Set best_model_checkpoint only when ckpt exists.

Previously, best_model_checkpoint was set explicitly without checking whether the checkpoint directory actually exists. The assignment now lives inside _save_checkpoint and only happens when that directory exists.

* Added best_global_step to TrainerState.

* Added tests for best_model_checkpoint.

* Fixed hard-coded values in test to prevent fail.

* Added helper func and removed hard-coded best_step.

* Added side effect patch generator for _eval.

* Added evaluate side effect func.

* Removed erroneous patching.

* Fixed minor bug.

* Applied Ruff.

* Fixed Ruff problem in make style.

* Used Trainer.set_initial_training_values.
This commit is contained in:
Sean (Seok-Won) Yi
2025-03-14 22:24:53 +09:00
committed by GitHub
parent 3bd1a0ddf1
commit 691d1b52c3
4 changed files with 231 additions and 6 deletions

View File

@@ -38,7 +38,7 @@ from dataclasses import MISSING, fields
from functools import wraps
from io import StringIO
from pathlib import Path
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union
from typing import Callable, Dict, Generator, Iterable, Iterator, List, Optional, Union
from unittest import mock
from unittest.mock import patch
@@ -48,6 +48,7 @@ import urllib3
from huggingface_hub import delete_repo
from packaging import version
from transformers import Trainer
from transformers import logging as transformers_logging
from .integrations import (
@@ -1440,6 +1441,34 @@ def get_tests_dir(append_path=None):
return tests_dir
def get_steps_per_epoch(trainer: Trainer) -> int:
    """
    Return the steps-per-epoch value that `trainer` derives for its own training dataloader.

    The trainer's `set_initial_training_values` is called with the trainer's arguments, its training
    dataloader and its per-device train batch size; the element at index 1 of the returned sequence is
    what this helper hands back as the steps-per-epoch count.
    """
    args = trainer.args
    values = trainer.set_initial_training_values(
        args=args,
        dataloader=trainer.get_train_dataloader(),
        total_train_batch_size=args.per_device_train_batch_size,
    )
    # Index 1 is the entry this helper treats as "steps per epoch".
    return values[1]
def evaluate_side_effect_factory(
    side_effect_values: List[Dict[str, float]],
) -> Generator[Dict[str, float], None, None]:
    """
    Build an endless generator of side effects for a patched `_evaluate` method.

    Used when we're unsure of exactly how many times `_evaluate` will be called: every entry of
    `side_effect_values` is yielded once, in order, and afterwards the last entry is repeated forever.

    Args:
        side_effect_values: Metric dictionaries to hand out, in call order. Must not be empty.

    Returns:
        A generator that never terminates on its own.

    Raises:
        ValueError: If `side_effect_values` is empty. The check runs when the factory is called, so a
            misconfigured test fails at setup time instead of with an `IndexError` on the first
            patched call.
    """
    if not side_effect_values:
        raise ValueError("`side_effect_values` must contain at least one entry.")

    def _side_effects() -> Generator[Dict[str, float], None, None]:
        for side_effect_value in side_effect_values:
            yield side_effect_value

        # Once the scripted values run out, keep answering with the final one.
        while True:
            yield side_effect_values[-1]

    return _side_effects()
#
# Helper functions for dealing with testing text outputs
# The original code came from:

View File

@@ -3178,12 +3178,10 @@ class Trainer:
self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf")
if operator(metric_value, self.state.best_metric):
run_dir = self._get_output_dir(trial=trial)
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
output_dir = os.path.join(run_dir, checkpoint_folder)
self.state.best_metric = metric_value
self.state.best_model_checkpoint = output_dir
if self.args.save_strategy in [SaveStrategy.STEPS, SaveStrategy.EPOCH]:
self.state.best_global_step = self.state.global_step
is_new_best_metric = True
@@ -3204,6 +3202,13 @@ class Trainer:
output_dir = os.path.join(run_dir, checkpoint_folder)
self.save_model(output_dir, _internal_call=True)
if self.args.save_strategy in [SaveStrategy.STEPS, SaveStrategy.EPOCH] and self.state.best_global_step:
best_checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.best_global_step}"
best_checkpoint_dir = os.path.join(run_dir, best_checkpoint_folder)
if os.path.exists(best_checkpoint_dir):
self.state.best_model_checkpoint = best_checkpoint_dir
if not self.args.save_only_model:
# Save optimizer and scheduler
self._save_optimizer_and_scheduler(output_dir)

View File

@@ -74,6 +74,9 @@ class TrainerState:
The list of logs done since the beginning of training.
best_metric (`float`, *optional*):
When tracking the best model, the value of the best metric encountered so far.
best_global_step (`int`, *optional*):
When tracking the best model, the step at which the best metric was encountered.
Used for setting `best_model_checkpoint`.
best_model_checkpoint (`str`, *optional*):
When tracking the best model, the value of the name of the checkpoint for the best model encountered so
far.
@@ -103,6 +106,7 @@ class TrainerState:
total_flos: float = 0
log_history: List[Dict[str, float]] = None
best_metric: Optional[float] = None
best_global_step: Optional[int] = None
best_model_checkpoint: Optional[str] = None
is_local_process_zero: bool = True
is_world_process_zero: bool = True

View File

@@ -62,8 +62,10 @@ from transformers.testing_utils import (
TemporaryHubRepo,
TestCasePlus,
backend_device_count,
evaluate_side_effect_factory,
execute_subprocess_async,
get_gpu_count,
get_steps_per_epoch,
get_tests_dir,
is_staging_test,
require_accelerate,
@@ -4710,6 +4712,191 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
)
self.assertTrue(trainer.args.metric_for_best_model == "loss")
def test_best_model_checkpoint_behavior(self):
    """
    Check how `best_metric`, `best_global_step` and `best_model_checkpoint` end up in the trainer state
    for several combinations of evaluation and save settings.

    In the cases that score a model, `_evaluate` is patched so that the returned metrics are scripted and
    the first evaluation is always the best one. `best_model_checkpoint` is expected to be set only when a
    checkpoint directory for the best step exists on disk.
    """
    # Case 1. Never evaluated, no save_total_limit passed and save_steps == 1.
    # Both best_metric and best_model_checkpoint should be None.
    # NOTE(review): `_evaluate` is not patched and `eval_steps` is not passed here; the case relies on no
    # evaluation being triggered during the run - confirm against the default evaluation interval.
    with tempfile.TemporaryDirectory() as tmpdir:
        trainer = get_regression_trainer(
            output_dir=tmpdir,
            eval_strategy="steps",
            save_strategy="steps",
            save_steps=1,
            metric_for_best_model="accuracy",
            greater_is_better=True,
        )
        trainer.train()

        self.assertIsNone(trainer.state.best_metric)
        self.assertIsNone(trainer.state.best_model_checkpoint)
        # One checkpoint directory per training step.
        self.assertEqual(len(os.listdir(tmpdir)), trainer.state.global_step)

    # Case 2. Never evaluated and save_total_limit == 1.
    # Both best_metric and best_model_checkpoint should be None.
    # Only the last checkpoint should remain.
    with tempfile.TemporaryDirectory() as tmpdir:
        trainer = get_regression_trainer(
            output_dir=tmpdir,
            eval_strategy="steps",
            save_strategy="steps",
            save_steps=1,
            metric_for_best_model="accuracy",
            greater_is_better=True,
            save_total_limit=1,
        )
        trainer.train()

        num_steps = trainer.state.global_step
        last_ckpt_name = f"{PREFIX_CHECKPOINT_DIR}-{num_steps}"

        self.assertIsNone(trainer.state.best_metric)
        self.assertIsNone(trainer.state.best_model_checkpoint)
        self.assertEqual(len(os.listdir(tmpdir)), 1)

        ckpt = os.path.join(tmpdir, last_ckpt_name)
        self.assertTrue(os.path.isdir(ckpt))
        self.assertEqual(os.listdir(tmpdir)[0], last_ckpt_name)

    # Case 3. eval_strategy == save_strategy.
    # best_model_checkpoint should be at epoch 1.
    with tempfile.TemporaryDirectory() as tmpdir:
        trainer = get_regression_trainer(
            output_dir=tmpdir,
            eval_strategy="epoch",
            save_strategy="epoch",
            metric_for_best_model="accuracy",
            compute_metrics=AlmostAccuracy(),
            greater_is_better=True,
            load_best_model_at_end=False,
        )

        with patch.object(
            trainer,
            "_evaluate",
            side_effect=evaluate_side_effect_factory(
                [
                    {"eval_accuracy": 0.59},
                    {"eval_accuracy": 0.57},
                    {"eval_accuracy": 0.55},
                ]
            ),
        ):
            trainer.train()

        steps_per_epoch = get_steps_per_epoch(trainer)

        self.assertEqual(trainer.state.best_metric, 0.59)
        self.assertEqual(trainer.state.best_global_step, steps_per_epoch)

        best_ckpt = os.path.join(tmpdir, f"{PREFIX_CHECKPOINT_DIR}-{trainer.state.best_global_step}")
        self.assertEqual(trainer.state.best_model_checkpoint, best_ckpt)

        # One checkpoint directory per epoch.
        self.assertEqual(len(os.listdir(tmpdir)), trainer.state.num_train_epochs)

    # Case 4. eval_strategy != save_strategy.
    # Evaluation happens per epoch while a checkpoint is written at every step, so the checkpoint for the
    # best step is expected to exist.
    with tempfile.TemporaryDirectory() as tmpdir:
        trainer = get_regression_trainer(
            output_dir=tmpdir,
            eval_strategy="epoch",
            save_strategy="steps",
            save_steps=1,
            metric_for_best_model="accuracy",
            compute_metrics=AlmostAccuracy(),
            greater_is_better=True,
            load_best_model_at_end=False,
        )

        with patch.object(
            trainer,
            "_evaluate",
            side_effect=evaluate_side_effect_factory(
                [
                    {"eval_accuracy": 0.59},
                    {"eval_accuracy": 0.57},
                    {"eval_accuracy": 0.55},
                ]
            ),
        ):
            trainer.train()

        steps_per_epoch = get_steps_per_epoch(trainer)

        self.assertEqual(trainer.state.best_metric, 0.59)
        self.assertEqual(trainer.state.best_global_step, steps_per_epoch)

        best_ckpt = os.path.join(tmpdir, f"{PREFIX_CHECKPOINT_DIR}-{trainer.state.best_global_step}")
        self.assertEqual(trainer.state.best_model_checkpoint, best_ckpt)

        # One checkpoint directory per training step.
        self.assertEqual(len(os.listdir(tmpdir)), trainer.state.global_step)

    # Case 5. Multiple checkpoints, save_total_limit == 1.
    # Best metric is found at step 1 and that checkpoint should be saved.
    with tempfile.TemporaryDirectory() as tmpdir:
        trainer = get_regression_trainer(
            output_dir=tmpdir,
            eval_strategy="steps",
            eval_steps=1,
            save_strategy="steps",
            save_steps=1,
            metric_for_best_model="accuracy",
            compute_metrics=AlmostAccuracy(),
            greater_is_better=True,
            save_total_limit=1,
        )

        with patch.object(
            trainer,
            "_evaluate",
            side_effect=evaluate_side_effect_factory(
                [
                    {"eval_accuracy": 0.90},
                    {"eval_accuracy": 0.80},
                    {"eval_accuracy": 0.70},
                ]
            ),
        ):
            trainer.train()

        self.assertEqual(trainer.state.best_metric, 0.90)
        self.assertEqual(trainer.state.best_global_step, 1)

        best_ckpt = os.path.join(tmpdir, f"{PREFIX_CHECKPOINT_DIR}-{trainer.state.best_global_step}")
        self.assertEqual(trainer.state.best_model_checkpoint, best_ckpt)

        self.assertEqual(len(os.listdir(tmpdir)), 1)

    # Case 6. Saving happens more often and eval/save mismatch.
    # `best_model_checkpoint` should be None due to a step mismatch:
    # the best metric is seen at step 3, but checkpoints are only written every 2 steps.
    with tempfile.TemporaryDirectory() as tmpdir:
        trainer = get_regression_trainer(
            output_dir=tmpdir,
            eval_strategy="steps",
            eval_steps=3,
            save_strategy="steps",
            save_steps=2,
            metric_for_best_model="accuracy",
            compute_metrics=AlmostAccuracy(),
            greater_is_better=True,
        )

        with patch.object(
            trainer,
            "_evaluate",
            side_effect=evaluate_side_effect_factory(
                [
                    {"eval_accuracy": 0.90},
                    {"eval_accuracy": 0.80},
                    {"eval_accuracy": 0.70},
                ]
            ),
        ):
            trainer.train()

        self.assertEqual(trainer.state.best_metric, 0.90)
        self.assertEqual(trainer.state.best_global_step, 3)
        self.assertIsNone(trainer.state.best_model_checkpoint)

        # One checkpoint directory for every second training step.
        self.assertEqual(len(os.listdir(tmpdir)), trainer.state.global_step // 2)
@require_torch
@is_staging_test