Adding a new return_full_text parameter to TextGenerationPipeline. (#9852)

* Adding a new `return_full_text` parameter to TextGenerationPipeline.

For text-generation, it's sometimes used as prompting text.
In that context, prefixing `generated_text` with the actual input
forces the caller to take an extra step to remove it.

The proposed change adds a new parameter (for backward compatibility).
`return_full_text` that enables the caller to prevent adding the prefix.

* Doc quality.
This commit is contained in:
Nicolas Patry
2021-01-29 10:27:32 +01:00
committed by GitHub
parent bc109ae5b8
commit c2d0ffec8c
2 changed files with 33 additions and 3 deletions

View File

@@ -15,6 +15,7 @@
import unittest
from transformers import pipeline
from transformers.testing_utils import require_torch
from .test_pipelines_common import MonoInputPipelineCommonMixin
@@ -41,3 +42,21 @@ class TextGenerationPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
self.assertEqual(type(outputs[0][0]["generated_text"]), str)
self.assertEqual(list(outputs[1][0].keys()), ["generated_text"])
self.assertEqual(type(outputs[1][0]["generated_text"]), str)
@require_torch
def test_generation_output_style(self):
text_generator = pipeline(task="text-generation", model=self.small_models[0])
# text-generation is non-deterministic by nature, we can't fully test the output
outputs = text_generator("This is a test")
self.assertIn("This is a test", outputs[0]["generated_text"])
outputs = text_generator("This is a test", return_full_text=False)
self.assertNotIn("This is a test", outputs[0]["generated_text"])
text_generator = pipeline(task="text-generation", model=self.small_models[0], return_full_text=False)
outputs = text_generator("This is a test")
self.assertNotIn("This is a test", outputs[0]["generated_text"])
outputs = text_generator("This is a test", return_full_text=True)
self.assertIn("This is a test", outputs[0]["generated_text"])