Adding a new return_full_text parameter to TextGenerationPipeline. (#9852)
* Adding a new `return_full_text` parameter to TextGenerationPipeline. For text-generation, it's sometimes used as prompting text. In that context, prefixing `generated_text` with the actual input forces the caller to take an extra step to remove it. The proposed change adds a new parameter (for backward compatibility). `return_full_text` that enables the caller to prevent adding the prefix. * Doc quality.
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
import unittest
|
||||
|
||||
from transformers import pipeline
|
||||
from transformers.testing_utils import require_torch
|
||||
|
||||
from .test_pipelines_common import MonoInputPipelineCommonMixin
|
||||
|
||||
@@ -41,3 +42,21 @@ class TextGenerationPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
|
||||
self.assertEqual(type(outputs[0][0]["generated_text"]), str)
|
||||
self.assertEqual(list(outputs[1][0].keys()), ["generated_text"])
|
||||
self.assertEqual(type(outputs[1][0]["generated_text"]), str)
|
||||
|
||||
@require_torch
|
||||
def test_generation_output_style(self):
|
||||
text_generator = pipeline(task="text-generation", model=self.small_models[0])
|
||||
# text-generation is non-deterministic by nature, we can't fully test the output
|
||||
|
||||
outputs = text_generator("This is a test")
|
||||
self.assertIn("This is a test", outputs[0]["generated_text"])
|
||||
|
||||
outputs = text_generator("This is a test", return_full_text=False)
|
||||
self.assertNotIn("This is a test", outputs[0]["generated_text"])
|
||||
|
||||
text_generator = pipeline(task="text-generation", model=self.small_models[0], return_full_text=False)
|
||||
outputs = text_generator("This is a test")
|
||||
self.assertNotIn("This is a test", outputs[0]["generated_text"])
|
||||
|
||||
outputs = text_generator("This is a test", return_full_text=True)
|
||||
self.assertIn("This is a test", outputs[0]["generated_text"])
|
||||
|
||||
Reference in New Issue
Block a user