From 691d1b52c33b0bda5187dc5afad8a30633003bb2 Mon Sep 17 00:00:00 2001 From: "Sean (Seok-Won) Yi" Date: Fri, 14 Mar 2025 22:24:53 +0900 Subject: [PATCH] Fix/best model checkpoint fix (#35885) * Set best_model_checkpoint only when ckpt exists. Rather than set it explicitly without checking if the checkpoint directory even exists as before, now we moved the setting logic inside of _save_checkpoint and are only setting it if it exists. * Added best_global_step to TrainerState. * Added tests for best_model_checkpoint. * Fixed hard-coded values in test to prevent fail. * Added helper func and removed hard-coded best_step. * Added side effect patch generator for _eval. * Added evaluate side effect func. * Removed erroneous patching. * Fixed minor bug. * Applied Ruff. * Fixed Ruff problem in make style. * Used Trainer.set_initial_training_values. --- src/transformers/testing_utils.py | 31 ++++- src/transformers/trainer.py | 15 ++- src/transformers/trainer_callback.py | 4 + tests/trainer/test_trainer.py | 187 +++++++++++++++++++++++++++ 4 files changed, 231 insertions(+), 6 deletions(-) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 0b4307dff6..d75c2d778f 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -38,7 +38,7 @@ from dataclasses import MISSING, fields from functools import wraps from io import StringIO from pathlib import Path -from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union +from typing import Callable, Dict, Generator, Iterable, Iterator, List, Optional, Union from unittest import mock from unittest.mock import patch @@ -48,6 +48,7 @@ import urllib3 from huggingface_hub import delete_repo from packaging import version +from transformers import Trainer from transformers import logging as transformers_logging from .integrations import ( @@ -1440,6 +1441,34 @@ def get_tests_dir(append_path=None): return tests_dir +def get_steps_per_epoch(trainer: Trainer) -> int: + training_args = trainer.args + train_dataloader = trainer.get_train_dataloader() + + initial_training_values = trainer.set_initial_training_values( + args=training_args, + dataloader=train_dataloader, + total_train_batch_size=training_args.per_device_train_batch_size, + ) + steps_per_epoch = initial_training_values[1] + + return steps_per_epoch + + +def evaluate_side_effect_factory( + side_effect_values: List[Dict[str, float]], +) -> Generator[Dict[str, float], None, None]: + """ + Function that returns side effects for the _evaluate method. + Used when we're unsure of exactly how many times _evaluate will be called. + """ + for side_effect_value in side_effect_values: + yield side_effect_value + + while True: + yield side_effect_values[-1] + + # # Helper functions for dealing with testing text outputs # The original code came from: diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 274defeee1..bea32918b6 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -3178,12 +3178,10 @@ class Trainer: self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf") if operator(metric_value, self.state.best_metric): - run_dir = self._get_output_dir(trial=trial) - checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" - output_dir = os.path.join(run_dir, checkpoint_folder) - self.state.best_metric = metric_value - self.state.best_model_checkpoint = output_dir + + if self.args.save_strategy in [SaveStrategy.STEPS, SaveStrategy.EPOCH]: + self.state.best_global_step = self.state.global_step is_new_best_metric = True @@ -3204,6 +3202,13 @@ class Trainer: output_dir = os.path.join(run_dir, checkpoint_folder) self.save_model(output_dir, _internal_call=True) + if self.args.save_strategy in [SaveStrategy.STEPS, SaveStrategy.EPOCH] and self.state.best_global_step: + best_checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.best_global_step}" + best_checkpoint_dir = os.path.join(run_dir, best_checkpoint_folder) + + if os.path.exists(best_checkpoint_dir): + self.state.best_model_checkpoint = best_checkpoint_dir + if not self.args.save_only_model: # Save optimizer and scheduler self._save_optimizer_and_scheduler(output_dir) diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py index 9a5eecd782..027fce086c 100644 --- a/src/transformers/trainer_callback.py +++ b/src/transformers/trainer_callback.py @@ -74,6 +74,9 @@ class TrainerState: The list of logs done since the beginning of training. best_metric (`float`, *optional*): When tracking the best model, the value of the best metric encountered so far. + best_global_step (`int`, *optional*): + When tracking the best model, the step at which the best metric was encountered. + Used for setting `best_model_checkpoint`. best_model_checkpoint (`str`, *optional*): When tracking the best model, the value of the name of the checkpoint for the best model encountered so far. @@ -103,6 +106,7 @@ class TrainerState: total_flos: float = 0 log_history: List[Dict[str, float]] = None best_metric: Optional[float] = None + best_global_step: Optional[int] = None best_model_checkpoint: Optional[str] = None is_local_process_zero: bool = True is_world_process_zero: bool = True diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index c8d9f34ff5..c4c90d5dcb 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -62,8 +62,10 @@ from transformers.testing_utils import ( TemporaryHubRepo, TestCasePlus, backend_device_count, + evaluate_side_effect_factory, execute_subprocess_async, get_gpu_count, + get_steps_per_epoch, get_tests_dir, is_staging_test, require_accelerate, @@ -4710,6 +4712,191 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): ) self.assertTrue(trainer.args.metric_for_best_model == "loss") + def test_best_model_checkpoint_behavior(self): + # Case 1. Never evaluated, save_total_limit > 1 and save_steps == 1. + # Both best_metric and best_model_checkpoint should be None. + with tempfile.TemporaryDirectory() as tmpdir: + trainer = get_regression_trainer( + output_dir=tmpdir, + eval_strategy="steps", + save_strategy="steps", + save_steps=1, + metric_for_best_model="accuracy", + greater_is_better=True, + ) + trainer.train() + + assert trainer.state.best_metric is None + assert trainer.state.best_model_checkpoint is None + assert len(os.listdir(tmpdir)) == trainer.state.global_step + + # Case 2. Never evaluated and save_total_limit == 1. + # Both best_metric and best_model_checkpoint should be None. + # Only the last checkpoint should remain. + with tempfile.TemporaryDirectory() as tmpdir: + trainer = get_regression_trainer( + output_dir=tmpdir, + eval_strategy="steps", + save_strategy="steps", + save_steps=1, + metric_for_best_model="accuracy", + greater_is_better=True, + save_total_limit=1, + ) + trainer.train() + + num_steps = trainer.state.global_step + + assert trainer.state.best_metric is None + assert trainer.state.best_model_checkpoint is None + assert len(os.listdir(tmpdir)) == 1 + + ckpt = os.path.join(tmpdir, f"{PREFIX_CHECKPOINT_DIR}-{num_steps}") + assert os.path.isdir(ckpt) + assert os.listdir(tmpdir)[0] == f"{PREFIX_CHECKPOINT_DIR}-{num_steps}" + + # Case 3. eval_strategy == save_strategy. + # best_model_checkpoint should be at epoch 1. + with tempfile.TemporaryDirectory() as tmpdir: + trainer = get_regression_trainer( + output_dir=tmpdir, + eval_strategy="epoch", + save_strategy="epoch", + metric_for_best_model="accuracy", + compute_metrics=AlmostAccuracy(), + greater_is_better=True, + load_best_model_at_end=False, + ) + + with patch.object( + trainer, + "_evaluate", + side_effect=evaluate_side_effect_factory( + [ + {"eval_accuracy": 0.59}, + {"eval_accuracy": 0.57}, + {"eval_accuracy": 0.55}, + ] + ), + ): + trainer.train() + + steps_per_epoch = get_steps_per_epoch(trainer) + + assert trainer.state.best_metric == 0.59 + assert trainer.state.best_global_step == steps_per_epoch + + best_ckpt = os.path.join(tmpdir, f"{PREFIX_CHECKPOINT_DIR}-{trainer.state.best_global_step}") + assert trainer.state.best_model_checkpoint == best_ckpt + + assert len(os.listdir(tmpdir)) == trainer.state.num_train_epochs + + # Case 4. eval_strategy != save_strategy. + with tempfile.TemporaryDirectory() as tmpdir: + trainer = get_regression_trainer( + output_dir=tmpdir, + eval_strategy="epoch", + save_strategy="steps", + save_steps=1, + metric_for_best_model="accuracy", + compute_metrics=AlmostAccuracy(), + greater_is_better=True, + load_best_model_at_end=False, + ) + + with patch.object( + trainer, + "_evaluate", + side_effect=evaluate_side_effect_factory( + [ + {"eval_accuracy": 0.59}, + {"eval_accuracy": 0.57}, + {"eval_accuracy": 0.55}, + ] + ), + ): + trainer.train() + + steps_per_epoch = get_steps_per_epoch(trainer) + + assert trainer.state.best_metric == 0.59 + assert trainer.state.best_global_step == steps_per_epoch + + best_ckpt = os.path.join(tmpdir, f"{PREFIX_CHECKPOINT_DIR}-{trainer.state.best_global_step}") + assert trainer.state.best_model_checkpoint == best_ckpt + + assert len(os.listdir(tmpdir)) == trainer.state.global_step + + # Case 5. Multiple checkpoints, save_total_limit == 1. + # Best metric is found at step 1 and that checkpoint should be saved. + with tempfile.TemporaryDirectory() as tmpdir: + trainer = get_regression_trainer( + output_dir=tmpdir, + eval_strategy="steps", + eval_steps=1, + save_strategy="steps", + save_steps=1, + metric_for_best_model="accuracy", + compute_metrics=AlmostAccuracy(), + greater_is_better=True, + save_total_limit=1, + ) + + with patch.object( + trainer, + "_evaluate", + side_effect=evaluate_side_effect_factory( + [ + {"eval_accuracy": 0.90}, + {"eval_accuracy": 0.80}, + {"eval_accuracy": 0.70}, + ] + ), + ): + trainer.train() + + assert trainer.state.best_metric == 0.90 + assert trainer.state.best_global_step == 1 + + best_ckpt = os.path.join(tmpdir, f"{PREFIX_CHECKPOINT_DIR}-{trainer.state.best_global_step}") + assert trainer.state.best_model_checkpoint == best_ckpt + + assert len(os.listdir(tmpdir)) == 1 + + # Case 6. Saving happens more often and eval/save mismatch. + # `best_model_checkpoint` should be None due to a step mismatch. + with tempfile.TemporaryDirectory() as tmpdir: + trainer = get_regression_trainer( + output_dir=tmpdir, + eval_strategy="steps", + eval_steps=3, + save_strategy="steps", + save_steps=2, + metric_for_best_model="accuracy", + compute_metrics=AlmostAccuracy(), + greater_is_better=True, + ) + + with patch.object( + trainer, + "_evaluate", + side_effect=evaluate_side_effect_factory( + [ + {"eval_accuracy": 0.90}, + {"eval_accuracy": 0.80}, + {"eval_accuracy": 0.70}, + ] + ), + ): + trainer.train() + + assert trainer.state.best_metric == 0.90 + assert trainer.state.best_global_step == 3 + + assert trainer.state.best_model_checkpoint is None + + assert len(os.listdir(tmpdir)) == trainer.state.global_step // 2 + @require_torch @is_staging_test