From 82486e5995ed0a65520b10ce1ea938214a199231 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 2 Jul 2024 15:17:42 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A8=F0=9F=9A=A8=20TextGenerationPipeli?= =?UTF-8?q?ne:=20rely=20on=20the=20tokenizer=20default=20kwargs=20(#31747)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * rely on the tokenizer default kwargs * fix a few tests --- src/transformers/pipelines/text_generation.py | 28 ++++++++++--------- tests/generation/test_utils.py | 13 +++------ .../test_pipelines_text_generation.py | 2 +- 3 files changed, 20 insertions(+), 23 deletions(-) diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index c2dce89dd7..994a517485 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -266,31 +266,33 @@ class TextGenerationPipeline(Pipeline): prompt_text, prefix="", handle_long_generation=None, - add_special_tokens=False, + add_special_tokens=None, truncation=None, - padding=False, + padding=None, max_length=None, **generate_kwargs, ): if isinstance(prompt_text, Chat): + # Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults + tokenizer_kwargs = {} + for tokenizer_kwarg_name in ["truncation", "padding", "max_length"]: + if locals()[tokenizer_kwarg_name] is not None: + tokenizer_kwargs[tokenizer_kwarg_name] = locals()[tokenizer_kwarg_name] inputs = self.tokenizer.apply_chat_template( prompt_text.messages, - truncation=truncation, - padding=padding, - max_length=max_length, add_generation_prompt=True, return_dict=True, return_tensors=self.framework, + **tokenizer_kwargs, ) else: - inputs = self.tokenizer( - prefix + prompt_text, - truncation=truncation, - padding=padding, - max_length=max_length, - add_special_tokens=add_special_tokens, - return_tensors=self.framework, - ) + # Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults + tokenizer_kwargs = {} + for tokenizer_kwarg_name in ["add_special_tokens", "truncation", "padding", "max_length"]: + if locals()[tokenizer_kwarg_name] is not None: + tokenizer_kwargs[tokenizer_kwarg_name] = locals()[tokenizer_kwarg_name] + inputs = self.tokenizer(prefix + prompt_text, return_tensors=self.framework, **tokenizer_kwargs) + inputs["prompt_text"] = prompt_text if handle_long_generation == "hole": diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 469bfa9206..b9e962a6a1 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2087,6 +2087,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi [1, 18], ) + # TODO (joao): replace `stop_sequence` in the pipeline by the more recent `generate` functionality def test_stop_sequence_stopping_criteria(self): # PT-only test: TF doesn't have StoppingCriteria prompt = """Hello I believe in""" @@ -2094,17 +2095,11 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi output = generator(prompt) self.assertEqual( output, - [ - { - "generated_text": ( - "Hello I believe in in in number number number number number number number number number" - ) - } - ], + [{"generated_text": ("Hello I believe in we we we we we we we we we")}], ) - output = generator(prompt, stop_sequence=" number") - self.assertEqual(output, [{"generated_text": "Hello I believe in in in number"}]) + output = generator(prompt, stop_sequence=" we") + self.assertEqual(output, [{"generated_text": "Hello I believe in we"}]) def test_generate_non_nlp_input_ids_as_kwarg(self): # PT-only test: AFAIK there's no non-NLP model architecture in TF that supports `input_ids` as its only input diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index 00ddd77f82..695befe329 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -398,7 +398,7 @@ class TextGenerationPipelineTests(unittest.TestCase): self.assertEqual(outputs, [{"generated_text": ANY(str)}]) else: with self.assertRaises((ValueError, AssertionError)): - outputs = text_generator("") + outputs = text_generator("", add_special_tokens=False) if text_generator.framework == "tf": # TF generation does not support max_new_tokens, and it's impossible