🚨🚨 TextGenerationPipeline: rely on the tokenizer default kwargs (#31747)
* rely on the tokenizer default kwargs
* fix a few tests
This commit is contained in:
@@ -266,31 +266,33 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
prompt_text,
|
prompt_text,
|
||||||
prefix="",
|
prefix="",
|
||||||
handle_long_generation=None,
|
handle_long_generation=None,
|
||||||
add_special_tokens=False,
|
add_special_tokens=None,
|
||||||
truncation=None,
|
truncation=None,
|
||||||
padding=False,
|
padding=None,
|
||||||
max_length=None,
|
max_length=None,
|
||||||
**generate_kwargs,
|
**generate_kwargs,
|
||||||
):
|
):
|
||||||
if isinstance(prompt_text, Chat):
|
if isinstance(prompt_text, Chat):
|
||||||
|
# Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
|
||||||
|
tokenizer_kwargs = {}
|
||||||
|
for tokenizer_kwarg_name in ["truncation", "padding", "max_length"]:
|
||||||
|
if locals()[tokenizer_kwarg_name] is not None:
|
||||||
|
tokenizer_kwargs[tokenizer_kwarg_name] = locals()[tokenizer_kwarg_name]
|
||||||
inputs = self.tokenizer.apply_chat_template(
|
inputs = self.tokenizer.apply_chat_template(
|
||||||
prompt_text.messages,
|
prompt_text.messages,
|
||||||
truncation=truncation,
|
|
||||||
padding=padding,
|
|
||||||
max_length=max_length,
|
|
||||||
add_generation_prompt=True,
|
add_generation_prompt=True,
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
return_tensors=self.framework,
|
return_tensors=self.framework,
|
||||||
|
**tokenizer_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
inputs = self.tokenizer(
|
# Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
|
||||||
prefix + prompt_text,
|
tokenizer_kwargs = {}
|
||||||
truncation=truncation,
|
for tokenizer_kwarg_name in ["add_special_tokens", "truncation", "padding", "max_length"]:
|
||||||
padding=padding,
|
if locals()[tokenizer_kwarg_name] is not None:
|
||||||
max_length=max_length,
|
tokenizer_kwargs[tokenizer_kwarg_name] = locals()[tokenizer_kwarg_name]
|
||||||
add_special_tokens=add_special_tokens,
|
inputs = self.tokenizer(prefix + prompt_text, return_tensors=self.framework, **tokenizer_kwargs)
|
||||||
return_tensors=self.framework,
|
|
||||||
)
|
|
||||||
inputs["prompt_text"] = prompt_text
|
inputs["prompt_text"] = prompt_text
|
||||||
|
|
||||||
if handle_long_generation == "hole":
|
if handle_long_generation == "hole":
|
||||||
|
|||||||
@@ -2087,6 +2087,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
[1, 18],
|
[1, 18],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO (joao): replace `stop_sequence` in the pipeline by the more recent `generate` functionality
|
||||||
def test_stop_sequence_stopping_criteria(self):
|
def test_stop_sequence_stopping_criteria(self):
|
||||||
# PT-only test: TF doesn't have StoppingCriteria
|
# PT-only test: TF doesn't have StoppingCriteria
|
||||||
prompt = """Hello I believe in"""
|
prompt = """Hello I believe in"""
|
||||||
@@ -2094,17 +2095,11 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
output = generator(prompt)
|
output = generator(prompt)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
output,
|
output,
|
||||||
[
|
[{"generated_text": ("Hello I believe in we we we we we we we we we")}],
|
||||||
{
|
|
||||||
"generated_text": (
|
|
||||||
"Hello I believe in in in number number number number number number number number number"
|
|
||||||
)
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
output = generator(prompt, stop_sequence=" number")
|
output = generator(prompt, stop_sequence=" we")
|
||||||
self.assertEqual(output, [{"generated_text": "Hello I believe in in in number"}])
|
self.assertEqual(output, [{"generated_text": "Hello I believe in we"}])
|
||||||
|
|
||||||
def test_generate_non_nlp_input_ids_as_kwarg(self):
|
def test_generate_non_nlp_input_ids_as_kwarg(self):
|
||||||
# PT-only test: AFAIK there's no non-NLP model architecture in TF that supports `input_ids` as its only input
|
# PT-only test: AFAIK there's no non-NLP model architecture in TF that supports `input_ids` as its only input
|
||||||
|
|||||||
@@ -398,7 +398,7 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
|||||||
self.assertEqual(outputs, [{"generated_text": ANY(str)}])
|
self.assertEqual(outputs, [{"generated_text": ANY(str)}])
|
||||||
else:
|
else:
|
||||||
with self.assertRaises((ValueError, AssertionError)):
|
with self.assertRaises((ValueError, AssertionError)):
|
||||||
outputs = text_generator("")
|
outputs = text_generator("", add_special_tokens=False)
|
||||||
|
|
||||||
if text_generator.framework == "tf":
|
if text_generator.framework == "tf":
|
||||||
# TF generation does not support max_new_tokens, and it's impossible
|
# TF generation does not support max_new_tokens, and it's impossible
|
||||||
|
|||||||
Reference in New Issue
Block a user