🚨🚨🚨 [pipelines] update defaults in pipelines that can generate (#38129)

* pipeline generation defaults

* add max_new_tokens=20 in test pipelines

* pop all kwargs that are used to parameterize generation config

* add class attr that tell us whether a pipeline calls generate

* tmp commit

* pt text gen pipeline tests passing

* remove failing tf tests

* fix text gen pipeline mixin test corner case

* update text_to_audio pipeline tests

* trigger tests

* a few more tests

* skips

* some more audio tests

* not slow

* broken

* lower severity of generation mode errors

* fix all asr pipeline tests

* nit

* skip

* image to text pipeline tests

* text2test pipeline

* last pipelines

* fix flaky

* PR comments

* handle generate attrs more carefully in models that cant generate

* same as above
This commit is contained in:
Joao Gante
2025-05-19 18:02:06 +01:00
committed by GitHub
parent 6f9da7649f
commit 9c500015c5
23 changed files with 361 additions and 438 deletions

View File

@@ -204,8 +204,9 @@ class GenerationConfigTest(unittest.TestCase):
# By default we throw a short warning. However, we log with INFO level the details.
# Default: we don't log the incorrect input values, only a short summary. We explain how to get more details.
with CaptureLogger(logger) as captured_logs:
GenerationConfig(do_sample=False, temperature=0.5)
with LoggingLevel(logging.WARNING):
with CaptureLogger(logger) as captured_logs:
GenerationConfig(do_sample=False, temperature=0.5)
self.assertNotIn("0.5", captured_logs.out)
self.assertTrue(len(captured_logs.out) < 150) # short log
self.assertIn("Set `TRANSFORMERS_VERBOSITY=info` for more details", captured_logs.out)
@@ -259,9 +260,10 @@ class GenerationConfigTest(unittest.TestCase):
# Catch warnings
with warnings.catch_warnings(record=True) as captured_warnings:
# Catch logs (up to WARNING level, the default level)
logger = transformers_logging.get_logger("transformers.generation.configuration_utils")
with CaptureLogger(logger) as captured_logs:
config.save_pretrained(tmp_dir)
with LoggingLevel(logging.WARNING):
logger = transformers_logging.get_logger("transformers.generation.configuration_utils")
with CaptureLogger(logger) as captured_logs:
config.save_pretrained(tmp_dir)
self.assertEqual(len(captured_warnings), 0)
self.assertEqual(len(captured_logs.out), 0)
self.assertEqual(len(os.listdir(tmp_dir)), 1)