Adding a new return_full_text parameter to TextGenerationPipeline. (#9852)
* Adding a new `return_full_text` parameter to TextGenerationPipeline. In text generation, the input is sometimes used as prompting text. In that context, prefixing `generated_text` with the actual input forces the caller to take an extra step to remove it. The proposed change adds a new parameter, `return_full_text` (defaulting to `True` for backward compatibility), that lets the caller prevent the prefix from being added. * Doc quality.
This commit is contained in:
@@ -44,10 +44,11 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
"TFCTRLLMHeadModel",
|
"TFCTRLLMHeadModel",
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, return_full_text=True, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
self.check_model_type(self.ALLOWED_MODELS)
|
self.check_model_type(self.ALLOWED_MODELS)
|
||||||
|
self.return_full_text = return_full_text
|
||||||
|
|
||||||
# overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments
|
# overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments
|
||||||
def _parse_and_tokenize(self, *args, **kwargs):
|
def _parse_and_tokenize(self, *args, **kwargs):
|
||||||
@@ -65,6 +66,7 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
text_inputs,
|
text_inputs,
|
||||||
return_tensors=False,
|
return_tensors=False,
|
||||||
return_text=True,
|
return_text=True,
|
||||||
|
return_full_text=None,
|
||||||
clean_up_tokenization_spaces=False,
|
clean_up_tokenization_spaces=False,
|
||||||
prefix=None,
|
prefix=None,
|
||||||
**generate_kwargs
|
**generate_kwargs
|
||||||
@@ -79,6 +81,9 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
Whether or not to include the tensors of predictions (as token indices) in the outputs.
|
Whether or not to include the tensors of predictions (as token indices) in the outputs.
|
||||||
return_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
return_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
Whether or not to include the decoded texts in the outputs.
|
Whether or not to include the decoded texts in the outputs.
|
||||||
|
return_full_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
If set to :obj:`False` only added text is returned, otherwise the full text is returned. Only meaningful
|
||||||
|
if `return_text` is set to True.
|
||||||
clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether or not to clean up the potential extra spaces in the text output.
|
Whether or not to clean up the potential extra spaces in the text output.
|
||||||
prefix (:obj:`str`, `optional`):
|
prefix (:obj:`str`, `optional`):
|
||||||
@@ -94,6 +99,8 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
- **generated_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
|
- **generated_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
|
||||||
-- The token ids of the generated text.
|
-- The token ids of the generated text.
|
||||||
"""
|
"""
|
||||||
|
prefix = prefix if prefix is not None else self.model.config.prefix
|
||||||
|
return_full_text = return_full_text if return_full_text is not None else self.return_full_text
|
||||||
|
|
||||||
if isinstance(text_inputs, str):
|
if isinstance(text_inputs, str):
|
||||||
text_inputs = [text_inputs]
|
text_inputs = [text_inputs]
|
||||||
@@ -101,7 +108,6 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
for prompt_text in text_inputs:
|
for prompt_text in text_inputs:
|
||||||
# Manage correct placement of the tensors
|
# Manage correct placement of the tensors
|
||||||
with self.device_placement():
|
with self.device_placement():
|
||||||
prefix = prefix if prefix is not None else self.model.config.prefix
|
|
||||||
if prefix is None and self.model.__class__.__name__ in [
|
if prefix is None and self.model.__class__.__name__ in [
|
||||||
"XLNetLMHeadModel",
|
"XLNetLMHeadModel",
|
||||||
"TransfoXLLMHeadModel",
|
"TransfoXLLMHeadModel",
|
||||||
@@ -168,7 +174,12 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
record["generated_text"] = prompt_text + text[prompt_length:]
|
if return_full_text:
|
||||||
|
all_text = prompt_text + text[prompt_length:]
|
||||||
|
else:
|
||||||
|
all_text = text[prompt_length:]
|
||||||
|
|
||||||
|
record["generated_text"] = all_text
|
||||||
|
|
||||||
result.append(record)
|
result.append(record)
|
||||||
results += [result]
|
results += [result]
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import pipeline
|
from transformers import pipeline
|
||||||
|
from transformers.testing_utils import require_torch
|
||||||
|
|
||||||
from .test_pipelines_common import MonoInputPipelineCommonMixin
|
from .test_pipelines_common import MonoInputPipelineCommonMixin
|
||||||
|
|
||||||
@@ -41,3 +42,21 @@ class TextGenerationPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
|
|||||||
self.assertEqual(type(outputs[0][0]["generated_text"]), str)
|
self.assertEqual(type(outputs[0][0]["generated_text"]), str)
|
||||||
self.assertEqual(list(outputs[1][0].keys()), ["generated_text"])
|
self.assertEqual(list(outputs[1][0].keys()), ["generated_text"])
|
||||||
self.assertEqual(type(outputs[1][0]["generated_text"]), str)
|
self.assertEqual(type(outputs[1][0]["generated_text"]), str)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_generation_output_style(self):
|
||||||
|
text_generator = pipeline(task="text-generation", model=self.small_models[0])
|
||||||
|
# text-generation is non-deterministic by nature, we can't fully test the output
|
||||||
|
|
||||||
|
outputs = text_generator("This is a test")
|
||||||
|
self.assertIn("This is a test", outputs[0]["generated_text"])
|
||||||
|
|
||||||
|
outputs = text_generator("This is a test", return_full_text=False)
|
||||||
|
self.assertNotIn("This is a test", outputs[0]["generated_text"])
|
||||||
|
|
||||||
|
text_generator = pipeline(task="text-generation", model=self.small_models[0], return_full_text=False)
|
||||||
|
outputs = text_generator("This is a test")
|
||||||
|
self.assertNotIn("This is a test", outputs[0]["generated_text"])
|
||||||
|
|
||||||
|
outputs = text_generator("This is a test", return_full_text=True)
|
||||||
|
self.assertIn("This is a test", outputs[0]["generated_text"])
|
||||||
|
|||||||
Reference in New Issue
Block a user