[FIX] TextGenerationPipeline is currently broken. (#8256)
* [FIX] TextGenerationPipeline is currently broken. It's most likely due to #8180. What's missing is a multi vs single string handler at the beginning of the pipe. And also there was no testing of this pipeline. * Fixing Conversational tests too.
This commit is contained in:
@@ -9,26 +9,30 @@ from .test_pipelines_common import MonoInputPipelineCommonMixin
|
||||
DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0
|
||||
|
||||
|
||||
class TextGenerationPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||
class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||
pipeline_task = "conversational"
|
||||
small_models = [] # Models tested without the @slow decorator
|
||||
large_models = ["microsoft/DialoGPT-medium"] # Models tested with the @slow decorator
|
||||
valid_inputs = [Conversation("Hi there!"), [Conversation("Hi there!"), Conversation("How are you?")]]
|
||||
invalid_inputs = ["Hi there!", Conversation()]
|
||||
|
||||
def _test_pipeline(
|
||||
self, nlp
|
||||
): # e overide the default test method to check that the output is a `Conversation` object
|
||||
def _test_pipeline(self, nlp):
|
||||
# e overide the default test method to check that the output is a `Conversation` object
|
||||
self.assertIsNotNone(nlp)
|
||||
|
||||
mono_result = nlp(self.valid_inputs[0])
|
||||
# We need to recreate conversation for successive tests to pass as
|
||||
# Conversation objects get *consumed* by the pipeline
|
||||
conversation = Conversation("Hi there!")
|
||||
mono_result = nlp(conversation)
|
||||
self.assertIsInstance(mono_result, Conversation)
|
||||
|
||||
multi_result = nlp(self.valid_inputs[1])
|
||||
conversations = [Conversation("Hi there!"), Conversation("How are you?")]
|
||||
multi_result = nlp(conversations)
|
||||
self.assertIsInstance(multi_result, list)
|
||||
self.assertIsInstance(multi_result[0], Conversation)
|
||||
# Conversation have been consumed and are not valid anymore
|
||||
# Inactive conversations passed to the pipeline raise a ValueError
|
||||
self.assertRaises(ValueError, nlp, self.valid_inputs[1])
|
||||
self.assertRaises(ValueError, nlp, conversation)
|
||||
self.assertRaises(ValueError, nlp, conversations)
|
||||
|
||||
for bad_input in self.invalid_inputs:
|
||||
self.assertRaises(Exception, nlp, bad_input)
|
||||
|
||||
Reference in New Issue
Block a user