Generate: text generation pipeline no longer emits max_length warning when it is not set (#23139)
This commit is contained in:
@@ -14,8 +14,15 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING, TextGenerationPipeline, pipeline
|
||||
from transformers import (
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
TextGenerationPipeline,
|
||||
logging,
|
||||
pipeline,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
CaptureLogger,
|
||||
is_pipeline_test,
|
||||
require_accelerate,
|
||||
require_tf,
|
||||
@@ -323,3 +330,26 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
||||
|
||||
pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto", torch_dtype=torch.float16)
|
||||
pipe("This is a test", do_sample=True, top_p=0.5)
|
||||
|
||||
def test_pipeline_length_setting_warning(self):
|
||||
prompt = """Hello world"""
|
||||
text_generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2")
|
||||
if text_generator.model.framework == "tf":
|
||||
logger = logging.get_logger("transformers.generation.tf_utils")
|
||||
else:
|
||||
logger = logging.get_logger("transformers.generation.utils")
|
||||
logger_msg = "Both `max_new_tokens`" # The beggining of the message to be checked in this test
|
||||
|
||||
# Both are set by the user -> log warning
|
||||
with CaptureLogger(logger) as cl:
|
||||
_ = text_generator(prompt, max_length=10, max_new_tokens=1)
|
||||
self.assertIn(logger_msg, cl.out)
|
||||
|
||||
# The user only sets one -> no warning
|
||||
with CaptureLogger(logger) as cl:
|
||||
_ = text_generator(prompt, max_new_tokens=1)
|
||||
self.assertNotIn(logger_msg, cl.out)
|
||||
|
||||
with CaptureLogger(logger) as cl:
|
||||
_ = text_generator(prompt, max_length=10)
|
||||
self.assertNotIn(logger_msg, cl.out)
|
||||
|
||||
Reference in New Issue
Block a user