[generation] Less verbose warnings by default (#38179)

* tmp commit (imports broken)

* working version; update tests

* remove line break

* shorter msg

* dola checks need num_beams=1; other minor PR comments

* update early trainer failing on bad gen config

* make fixup

* test msg
This commit is contained in:
Joao Gante
2025-05-19 11:03:37 +01:00
committed by GitHub
parent 656e2eab3f
commit dbb9813dff
4 changed files with 251 additions and 247 deletions

View File

@@ -13,6 +13,7 @@
# limitations under the License.
import copy
import logging
import os
import tempfile
import unittest
@@ -22,6 +23,7 @@ from huggingface_hub import HfFolder, create_pull_request
from parameterized import parameterized
from transformers import AutoConfig, GenerationConfig, WatermarkingConfig, is_torch_available
from transformers import logging as transformers_logging
if is_torch_available():
@@ -55,7 +57,14 @@ from transformers.generation import (
UnbatchedClassifierFreeGuidanceLogitsProcessor,
WatermarkLogitsProcessor,
)
from transformers.testing_utils import TOKEN, TemporaryHubRepo, is_staging_test, torch_device
from transformers.testing_utils import (
TOKEN,
CaptureLogger,
LoggingLevel,
TemporaryHubRepo,
is_staging_test,
torch_device,
)
class GenerationConfigTest(unittest.TestCase):
@@ -112,24 +121,6 @@ class GenerationConfigTest(unittest.TestCase):
# `.update()` returns a dictionary of unused kwargs
self.assertEqual(unused_kwargs, {"foo": "bar"})
# TODO: @Arthur and/or @Joao
# FAILED tests/generation/test_configuration_utils.py::GenerationConfigTest::test_initialize_new_kwargs - AttributeError: 'GenerationConfig' object has no attribute 'get_text_config'
# See: https://app.circleci.com/pipelines/github/huggingface/transformers/104831/workflows/e5e61514-51b7-4c8c-bba7-3c4d2986956e/jobs/1394252
@unittest.skip("failed with `'GenerationConfig' object has no attribute 'get_text_config'`")
def test_initialize_new_kwargs(self):
generation_config = GenerationConfig()
generation_config.foo = "bar"
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
generation_config.save_pretrained(tmp_dir)
new_config = GenerationConfig.from_pretrained(tmp_dir)
# update_kwargs was used to update the config on valid attributes
self.assertEqual(new_config.foo, "bar")
generation_config = GenerationConfig.from_model_config(new_config)
assert not hasattr(generation_config, "foo") # no new kwargs should be initialized if from config
def test_kwarg_init(self):
"""Tests that we can overwrite attributes at `from_pretrained` time."""
default_config = GenerationConfig()
@@ -159,38 +150,39 @@ class GenerationConfigTest(unittest.TestCase):
"""
Tests that the `validate` method is working as expected. Note that `validate` is called at initialization time
"""
logger = transformers_logging.get_logger("transformers.generation.configuration_utils")
# A correct configuration will not throw any warning
with warnings.catch_warnings(record=True) as captured_warnings:
with CaptureLogger(logger) as captured_logs:
GenerationConfig()
self.assertEqual(len(captured_warnings), 0)
self.assertEqual(len(captured_logs.out), 0)
# Inconsequent but technically wrong configuration will throw a warning (e.g. setting sampling
# parameters with `do_sample=False`). May be escalated to an error in the future.
with warnings.catch_warnings(record=True) as captured_warnings:
GenerationConfig(do_sample=False, temperature=0.5)
self.assertEqual(len(captured_warnings), 1)
with warnings.catch_warnings(record=True) as captured_warnings:
with CaptureLogger(logger) as captured_logs:
GenerationConfig(return_dict_in_generate=False, output_scores=True)
self.assertEqual(len(captured_warnings), 1)
self.assertNotEqual(len(captured_logs.out), 0)
with CaptureLogger(logger) as captured_logs:
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5) # store for later
self.assertNotEqual(len(captured_logs.out), 0)
# Expanding on the case above, we can update a bad configuration to get rid of the warning. Ideally,
# that is done by unsetting the parameter (i.e. setting it to None)
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
with warnings.catch_warnings(record=True) as captured_warnings:
with CaptureLogger(logger) as captured_logs:
# BAD - 0.9 means it is still set, we should warn
generation_config_bad_temperature.update(temperature=0.9)
self.assertEqual(len(captured_warnings), 1)
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
with warnings.catch_warnings(record=True) as captured_warnings:
self.assertNotEqual(len(captured_logs.out), 0)
with CaptureLogger(logger) as captured_logs:
# CORNER CASE - 1.0 is the default, we can't detect whether it is set by the user or not, we shouldn't warn
generation_config_bad_temperature.update(temperature=1.0)
self.assertEqual(len(captured_warnings), 0)
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
with warnings.catch_warnings(record=True) as captured_warnings:
self.assertEqual(len(captured_logs.out), 0)
with CaptureLogger(logger) as captured_logs:
# OK - None means it is unset, nothing to warn about
generation_config_bad_temperature.update(temperature=None)
self.assertEqual(len(captured_warnings), 0)
self.assertEqual(len(captured_logs.out), 0)
# Impossible sets of constraints/parameters will raise an exception
with self.assertRaises(ValueError):
@@ -206,9 +198,32 @@ class GenerationConfigTest(unittest.TestCase):
GenerationConfig(logits_processor="foo")
# Model-specific parameters will NOT raise an exception or a warning
with warnings.catch_warnings(record=True) as captured_warnings:
with CaptureLogger(logger) as captured_logs:
GenerationConfig(foo="bar")
self.assertEqual(len(captured_warnings), 0)
self.assertEqual(len(captured_logs.out), 0)
# By default we throw a short warning. However, we log with INFO level the details.
# Default: we don't log the incorrect input values, only a short summary. We explain how to get more details.
with CaptureLogger(logger) as captured_logs:
GenerationConfig(do_sample=False, temperature=0.5)
self.assertNotIn("0.5", captured_logs.out)
self.assertTrue(len(captured_logs.out) < 150) # short log
self.assertIn("Set `TRANSFORMERS_VERBOSITY=info` for more details", captured_logs.out)
# INFO level: we share the full deets
with LoggingLevel(logging.INFO):
with CaptureLogger(logger) as captured_logs:
GenerationConfig(do_sample=False, temperature=0.5)
self.assertIn("0.5", captured_logs.out)
self.assertTrue(len(captured_logs.out) > 400) # long log
self.assertNotIn("Set `TRANSFORMERS_VERBOSITY=info` for more details", captured_logs.out)
# Finally, we can set `strict=True` to raise an exception on what would otherwise be a warning.
generation_config = GenerationConfig()
generation_config.temperature = 0.5
generation_config.do_sample = False
with self.assertRaises(ValueError):
generation_config.validate(strict=True)
def test_refuse_to_save(self):
"""Tests that we refuse to save a generation config that fails validation."""
@@ -221,6 +236,7 @@ class GenerationConfigTest(unittest.TestCase):
with self.assertRaises(ValueError) as exc:
config.save_pretrained(tmp_dir)
self.assertTrue("Fix these issues to save the configuration." in str(exc.exception))
self.assertTrue("`temperature` is set to `0.5`" in str(exc.exception))
self.assertTrue(len(os.listdir(tmp_dir)) == 0)
# greedy decoding throws an exception if we try to return multiple sequences -> throws an exception that is
@@ -231,15 +247,24 @@ class GenerationConfigTest(unittest.TestCase):
with self.assertRaises(ValueError) as exc:
config.save_pretrained(tmp_dir)
self.assertTrue("Fix these issues to save the configuration." in str(exc.exception))
self.assertTrue(
"Greedy methods without beam search do not support `num_return_sequences` different than 1"
in str(exc.exception)
)
self.assertTrue(len(os.listdir(tmp_dir)) == 0)
# final check: no warnings/exceptions thrown if it is correct, and file is saved
# Final check: no logs at warning level/warnings/exceptions thrown if it is correct, and file is saved.
config = GenerationConfig()
with tempfile.TemporaryDirectory() as tmp_dir:
# Catch warnings
with warnings.catch_warnings(record=True) as captured_warnings:
config.save_pretrained(tmp_dir)
# Catch logs (up to WARNING level, the default level)
logger = transformers_logging.get_logger("transformers.generation.configuration_utils")
with CaptureLogger(logger) as captured_logs:
config.save_pretrained(tmp_dir)
self.assertEqual(len(captured_warnings), 0)
self.assertTrue(len(os.listdir(tmp_dir)) == 1)
self.assertEqual(len(captured_logs.out), 0)
self.assertEqual(len(os.listdir(tmp_dir)), 1)
def test_generation_mode(self):
"""Tests that the `get_generation_mode` method is working as expected."""

View File

@@ -202,4 +202,4 @@ class Seq2seqTrainerTester(TestCasePlus):
data_collator=data_collator,
compute_metrics=lambda x: {"samples": x[0].shape[0]},
)
self.assertIn("The loaded generation config instance is invalid", str(exc.exception))
self.assertIn("Fix these issues to train your model", str(exc.exception))