[generation] Less verbose warnings by default (#38179)
* tmp commit (imports broken) * working version; update tests * remove line break * shorter msg * dola checks need num_beams=1; other minor PR comments * update early trainer failing on bad gen config * make fixup * test msg
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
@@ -22,6 +23,7 @@ from huggingface_hub import HfFolder, create_pull_request
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import AutoConfig, GenerationConfig, WatermarkingConfig, is_torch_available
|
||||
from transformers import logging as transformers_logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -55,7 +57,14 @@ from transformers.generation import (
|
||||
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||
WatermarkLogitsProcessor,
|
||||
)
|
||||
from transformers.testing_utils import TOKEN, TemporaryHubRepo, is_staging_test, torch_device
|
||||
from transformers.testing_utils import (
|
||||
TOKEN,
|
||||
CaptureLogger,
|
||||
LoggingLevel,
|
||||
TemporaryHubRepo,
|
||||
is_staging_test,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
class GenerationConfigTest(unittest.TestCase):
|
||||
@@ -112,24 +121,6 @@ class GenerationConfigTest(unittest.TestCase):
|
||||
# `.update()` returns a dictionary of unused kwargs
|
||||
self.assertEqual(unused_kwargs, {"foo": "bar"})
|
||||
|
||||
# TODO: @Arthur and/or @Joao
|
||||
# FAILED tests/generation/test_configuration_utils.py::GenerationConfigTest::test_initialize_new_kwargs - AttributeError: 'GenerationConfig' object has no attribute 'get_text_config'
|
||||
# See: https://app.circleci.com/pipelines/github/huggingface/transformers/104831/workflows/e5e61514-51b7-4c8c-bba7-3c4d2986956e/jobs/1394252
|
||||
@unittest.skip("failed with `'GenerationConfig' object has no attribute 'get_text_config'`")
|
||||
def test_initialize_new_kwargs(self):
|
||||
generation_config = GenerationConfig()
|
||||
generation_config.foo = "bar"
|
||||
|
||||
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
|
||||
generation_config.save_pretrained(tmp_dir)
|
||||
|
||||
new_config = GenerationConfig.from_pretrained(tmp_dir)
|
||||
# update_kwargs was used to update the config on valid attributes
|
||||
self.assertEqual(new_config.foo, "bar")
|
||||
|
||||
generation_config = GenerationConfig.from_model_config(new_config)
|
||||
assert not hasattr(generation_config, "foo") # no new kwargs should be initialized if from config
|
||||
|
||||
def test_kwarg_init(self):
|
||||
"""Tests that we can overwrite attributes at `from_pretrained` time."""
|
||||
default_config = GenerationConfig()
|
||||
@@ -159,38 +150,39 @@ class GenerationConfigTest(unittest.TestCase):
|
||||
"""
|
||||
Tests that the `validate` method is working as expected. Note that `validate` is called at initialization time
|
||||
"""
|
||||
logger = transformers_logging.get_logger("transformers.generation.configuration_utils")
|
||||
|
||||
# A correct configuration will not throw any warning
|
||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
GenerationConfig()
|
||||
self.assertEqual(len(captured_warnings), 0)
|
||||
self.assertEqual(len(captured_logs.out), 0)
|
||||
|
||||
# Inconsequent but technically wrong configuration will throw a warning (e.g. setting sampling
|
||||
# parameters with `do_sample=False`). May be escalated to an error in the future.
|
||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||
GenerationConfig(do_sample=False, temperature=0.5)
|
||||
self.assertEqual(len(captured_warnings), 1)
|
||||
|
||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
GenerationConfig(return_dict_in_generate=False, output_scores=True)
|
||||
self.assertEqual(len(captured_warnings), 1)
|
||||
self.assertNotEqual(len(captured_logs.out), 0)
|
||||
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5) # store for later
|
||||
self.assertNotEqual(len(captured_logs.out), 0)
|
||||
|
||||
# Expanding on the case above, we can update a bad configuration to get rid of the warning. Ideally,
|
||||
# that is done by unsetting the parameter (i.e. setting it to None)
|
||||
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
|
||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
# BAD - 0.9 means it is still set, we should warn
|
||||
generation_config_bad_temperature.update(temperature=0.9)
|
||||
self.assertEqual(len(captured_warnings), 1)
|
||||
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
|
||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||
self.assertNotEqual(len(captured_logs.out), 0)
|
||||
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
# CORNER CASE - 1.0 is the default, we can't detect whether it is set by the user or not, we shouldn't warn
|
||||
generation_config_bad_temperature.update(temperature=1.0)
|
||||
self.assertEqual(len(captured_warnings), 0)
|
||||
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
|
||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||
self.assertEqual(len(captured_logs.out), 0)
|
||||
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
# OK - None means it is unset, nothing to warn about
|
||||
generation_config_bad_temperature.update(temperature=None)
|
||||
self.assertEqual(len(captured_warnings), 0)
|
||||
self.assertEqual(len(captured_logs.out), 0)
|
||||
|
||||
# Impossible sets of constraints/parameters will raise an exception
|
||||
with self.assertRaises(ValueError):
|
||||
@@ -206,9 +198,32 @@ class GenerationConfigTest(unittest.TestCase):
|
||||
GenerationConfig(logits_processor="foo")
|
||||
|
||||
# Model-specific parameters will NOT raise an exception or a warning
|
||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
GenerationConfig(foo="bar")
|
||||
self.assertEqual(len(captured_warnings), 0)
|
||||
self.assertEqual(len(captured_logs.out), 0)
|
||||
|
||||
# By default we throw a short warning. However, we log with INFO level the details.
|
||||
# Default: we don't log the incorrect input values, only a short summary. We explain how to get more details.
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
GenerationConfig(do_sample=False, temperature=0.5)
|
||||
self.assertNotIn("0.5", captured_logs.out)
|
||||
self.assertTrue(len(captured_logs.out) < 150) # short log
|
||||
self.assertIn("Set `TRANSFORMERS_VERBOSITY=info` for more details", captured_logs.out)
|
||||
|
||||
# INFO level: we share the full deets
|
||||
with LoggingLevel(logging.INFO):
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
GenerationConfig(do_sample=False, temperature=0.5)
|
||||
self.assertIn("0.5", captured_logs.out)
|
||||
self.assertTrue(len(captured_logs.out) > 400) # long log
|
||||
self.assertNotIn("Set `TRANSFORMERS_VERBOSITY=info` for more details", captured_logs.out)
|
||||
|
||||
# Finally, we can set `strict=True` to raise an exception on what would otherwise be a warning.
|
||||
generation_config = GenerationConfig()
|
||||
generation_config.temperature = 0.5
|
||||
generation_config.do_sample = False
|
||||
with self.assertRaises(ValueError):
|
||||
generation_config.validate(strict=True)
|
||||
|
||||
def test_refuse_to_save(self):
|
||||
"""Tests that we refuse to save a generation config that fails validation."""
|
||||
@@ -221,6 +236,7 @@ class GenerationConfigTest(unittest.TestCase):
|
||||
with self.assertRaises(ValueError) as exc:
|
||||
config.save_pretrained(tmp_dir)
|
||||
self.assertTrue("Fix these issues to save the configuration." in str(exc.exception))
|
||||
self.assertTrue("`temperature` is set to `0.5`" in str(exc.exception))
|
||||
self.assertTrue(len(os.listdir(tmp_dir)) == 0)
|
||||
|
||||
# greedy decoding throws an exception if we try to return multiple sequences -> throws an exception that is
|
||||
@@ -231,15 +247,24 @@ class GenerationConfigTest(unittest.TestCase):
|
||||
with self.assertRaises(ValueError) as exc:
|
||||
config.save_pretrained(tmp_dir)
|
||||
self.assertTrue("Fix these issues to save the configuration." in str(exc.exception))
|
||||
self.assertTrue(
|
||||
"Greedy methods without beam search do not support `num_return_sequences` different than 1"
|
||||
in str(exc.exception)
|
||||
)
|
||||
self.assertTrue(len(os.listdir(tmp_dir)) == 0)
|
||||
|
||||
# final check: no warnings/exceptions thrown if it is correct, and file is saved
|
||||
# Final check: no logs at warning level/warnings/exceptions thrown if it is correct, and file is saved.
|
||||
config = GenerationConfig()
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Catch warnings
|
||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||
config.save_pretrained(tmp_dir)
|
||||
# Catch logs (up to WARNING level, the default level)
|
||||
logger = transformers_logging.get_logger("transformers.generation.configuration_utils")
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
config.save_pretrained(tmp_dir)
|
||||
self.assertEqual(len(captured_warnings), 0)
|
||||
self.assertTrue(len(os.listdir(tmp_dir)) == 1)
|
||||
self.assertEqual(len(captured_logs.out), 0)
|
||||
self.assertEqual(len(os.listdir(tmp_dir)), 1)
|
||||
|
||||
def test_generation_mode(self):
|
||||
"""Tests that the `get_generation_mode` method is working as expected."""
|
||||
|
||||
@@ -202,4 +202,4 @@ class Seq2seqTrainerTester(TestCasePlus):
|
||||
data_collator=data_collator,
|
||||
compute_metrics=lambda x: {"samples": x[0].shape[0]},
|
||||
)
|
||||
self.assertIn("The loaded generation config instance is invalid", str(exc.exception))
|
||||
self.assertIn("Fix these issues to train your model", str(exc.exception))
|
||||
|
||||
Reference in New Issue
Block a user