[FIX] TextGenerationPipeline is currently broken. (#8256)
* [FIX] TextGenerationPipeline is currently broken, most likely because of #8180. The pipeline was missing a handler at its start that distinguishes a single string input from a list of strings, and it had no tests covering it. * Also fix the Conversational pipeline tests.
This commit is contained in:
@@ -836,6 +836,8 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
-- The token ids of the generated text.
|
-- The token ids of the generated text.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if isinstance(text_inputs, str):
|
||||||
|
text_inputs = [text_inputs]
|
||||||
results = []
|
results = []
|
||||||
for prompt_text in text_inputs:
|
for prompt_text in text_inputs:
|
||||||
# Manage correct placement of the tensors
|
# Manage correct placement of the tensors
|
||||||
@@ -2382,6 +2384,8 @@ class ConversationalPipeline(Pipeline):
|
|||||||
updated generated responses for those containing a new user input.
|
updated generated responses for those containing a new user input.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if isinstance(conversations, Conversation):
|
||||||
|
conversations = [conversations]
|
||||||
# Input validation
|
# Input validation
|
||||||
if isinstance(conversations, list):
|
if isinstance(conversations, list):
|
||||||
for conversation in conversations:
|
for conversation in conversations:
|
||||||
@@ -2398,8 +2402,6 @@ class ConversationalPipeline(Pipeline):
|
|||||||
assert (
|
assert (
|
||||||
self.tokenizer.pad_token_id is not None or self.tokenizer.eos_token_id is not None
|
self.tokenizer.pad_token_id is not None or self.tokenizer.eos_token_id is not None
|
||||||
), "Please make sure that the tokenizer has a pad_token_id or eos_token_id when using a batch input"
|
), "Please make sure that the tokenizer has a pad_token_id or eos_token_id when using a batch input"
|
||||||
elif isinstance(conversations, Conversation):
|
|
||||||
conversations = [conversations]
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("DialoguePipeline expects a Conversation or list of Conversations as an input")
|
raise ValueError("DialoguePipeline expects a Conversation or list of Conversations as an input")
|
||||||
|
|
||||||
|
|||||||
@@ -9,26 +9,30 @@ from .test_pipelines_common import MonoInputPipelineCommonMixin
|
|||||||
DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0
|
DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0
|
||||||
|
|
||||||
|
|
||||||
class TextGenerationPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||||
pipeline_task = "conversational"
|
pipeline_task = "conversational"
|
||||||
small_models = [] # Models tested without the @slow decorator
|
small_models = [] # Models tested without the @slow decorator
|
||||||
large_models = ["microsoft/DialoGPT-medium"] # Models tested with the @slow decorator
|
large_models = ["microsoft/DialoGPT-medium"] # Models tested with the @slow decorator
|
||||||
valid_inputs = [Conversation("Hi there!"), [Conversation("Hi there!"), Conversation("How are you?")]]
|
|
||||||
invalid_inputs = ["Hi there!", Conversation()]
|
invalid_inputs = ["Hi there!", Conversation()]
|
||||||
|
|
||||||
def _test_pipeline(
|
def _test_pipeline(self, nlp):
|
||||||
self, nlp
|
# e overide the default test method to check that the output is a `Conversation` object
|
||||||
): # e overide the default test method to check that the output is a `Conversation` object
|
|
||||||
self.assertIsNotNone(nlp)
|
self.assertIsNotNone(nlp)
|
||||||
|
|
||||||
mono_result = nlp(self.valid_inputs[0])
|
# We need to recreate conversation for successive tests to pass as
|
||||||
|
# Conversation objects get *consumed* by the pipeline
|
||||||
|
conversation = Conversation("Hi there!")
|
||||||
|
mono_result = nlp(conversation)
|
||||||
self.assertIsInstance(mono_result, Conversation)
|
self.assertIsInstance(mono_result, Conversation)
|
||||||
|
|
||||||
multi_result = nlp(self.valid_inputs[1])
|
conversations = [Conversation("Hi there!"), Conversation("How are you?")]
|
||||||
|
multi_result = nlp(conversations)
|
||||||
self.assertIsInstance(multi_result, list)
|
self.assertIsInstance(multi_result, list)
|
||||||
self.assertIsInstance(multi_result[0], Conversation)
|
self.assertIsInstance(multi_result[0], Conversation)
|
||||||
|
# Conversation have been consumed and are not valid anymore
|
||||||
# Inactive conversations passed to the pipeline raise a ValueError
|
# Inactive conversations passed to the pipeline raise a ValueError
|
||||||
self.assertRaises(ValueError, nlp, self.valid_inputs[1])
|
self.assertRaises(ValueError, nlp, conversation)
|
||||||
|
self.assertRaises(ValueError, nlp, conversations)
|
||||||
|
|
||||||
for bad_input in self.invalid_inputs:
|
for bad_input in self.invalid_inputs:
|
||||||
self.assertRaises(Exception, nlp, bad_input)
|
self.assertRaises(Exception, nlp, bad_input)
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
from transformers import pipeline
|
||||||
|
|
||||||
from .test_pipelines_common import MonoInputPipelineCommonMixin
|
from .test_pipelines_common import MonoInputPipelineCommonMixin
|
||||||
|
|
||||||
|
|
||||||
@@ -8,3 +10,20 @@ class TextGenerationPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
|
|||||||
pipeline_running_kwargs = {"prefix": "This is "}
|
pipeline_running_kwargs = {"prefix": "This is "}
|
||||||
small_models = ["sshleifer/tiny-ctrl"] # Models tested without the @slow decorator
|
small_models = ["sshleifer/tiny-ctrl"] # Models tested without the @slow decorator
|
||||||
large_models = [] # Models tested with the @slow decorator
|
large_models = [] # Models tested with the @slow decorator
|
||||||
|
|
||||||
|
def test_simple_generation(self):
|
||||||
|
nlp = pipeline(task="text-generation", model=self.small_models[0])
|
||||||
|
# text-generation is non-deterministic by nature, we can't fully test the output
|
||||||
|
|
||||||
|
outputs = nlp("This is a test")
|
||||||
|
|
||||||
|
self.assertEqual(len(outputs), 1)
|
||||||
|
self.assertEqual(list(outputs[0].keys()), ["generated_text"])
|
||||||
|
self.assertEqual(type(outputs[0]["generated_text"]), str)
|
||||||
|
|
||||||
|
outputs = nlp(["This is a test", "This is a second test"])
|
||||||
|
self.assertEqual(len(outputs[0]), 1)
|
||||||
|
self.assertEqual(list(outputs[0][0].keys()), ["generated_text"])
|
||||||
|
self.assertEqual(type(outputs[0][0]["generated_text"]), str)
|
||||||
|
self.assertEqual(list(outputs[1][0].keys()), ["generated_text"])
|
||||||
|
self.assertEqual(type(outputs[1][0]["generated_text"]), str)
|
||||||
|
|||||||
Reference in New Issue
Block a user