[FIX] TextGenerationPipeline is currently broken. (#8256)

* [FIX] TextGenerationPipeline is currently broken.

It's most likely due to #8180.
What's missing is a multi vs single string handler at the beginning of
the pipe.
And also there was no testing of this pipeline.

* Fixing Conversational tests too.
This commit is contained in:
Nicolas Patry
2020-11-03 16:10:22 +01:00
committed by GitHub
parent a1bbcf3f6c
commit c66ffa3a17
3 changed files with 35 additions and 10 deletions

View File

@@ -1,5 +1,7 @@
import unittest
from transformers import pipeline
from .test_pipelines_common import MonoInputPipelineCommonMixin
@@ -8,3 +10,20 @@ class TextGenerationPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
pipeline_running_kwargs = {"prefix": "This is "}
small_models = ["sshleifer/tiny-ctrl"] # Models tested without the @slow decorator
large_models = [] # Models tested with the @slow decorator
def test_simple_generation(self):
nlp = pipeline(task="text-generation", model=self.small_models[0])
# text-generation is non-deterministic by nature, we can't fully test the output
outputs = nlp("This is a test")
self.assertEqual(len(outputs), 1)
self.assertEqual(list(outputs[0].keys()), ["generated_text"])
self.assertEqual(type(outputs[0]["generated_text"]), str)
outputs = nlp(["This is a test", "This is a second test"])
self.assertEqual(len(outputs[0]), 1)
self.assertEqual(list(outputs[0][0].keys()), ["generated_text"])
self.assertEqual(type(outputs[0][0]["generated_text"]), str)
self.assertEqual(list(outputs[1][0].keys()), ["generated_text"])
self.assertEqual(type(outputs[1][0]["generated_text"]), str)