Fix and re-enable ConversationalPipeline tests (#26907)
* Fix and re-enable conversationalpipeline tests * Fix the batch test so the change only applies to conversational pipeline
This commit is contained in:
@@ -77,14 +77,14 @@ class ConversationalPipelineTests(unittest.TestCase):
|
|||||||
|
|
||||||
def run_pipeline_test(self, conversation_agent, _):
|
def run_pipeline_test(self, conversation_agent, _):
|
||||||
# Simple
|
# Simple
|
||||||
outputs = conversation_agent(Conversation("Hi there!"))
|
outputs = conversation_agent(Conversation("Hi there!"), max_new_tokens=20)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
outputs,
|
outputs,
|
||||||
Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]),
|
Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Single list
|
# Single list
|
||||||
outputs = conversation_agent([Conversation("Hi there!")])
|
outputs = conversation_agent([Conversation("Hi there!")], max_new_tokens=20)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
outputs,
|
outputs,
|
||||||
Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]),
|
Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]),
|
||||||
@@ -96,7 +96,7 @@ class ConversationalPipelineTests(unittest.TestCase):
|
|||||||
self.assertEqual(len(conversation_1), 1)
|
self.assertEqual(len(conversation_1), 1)
|
||||||
self.assertEqual(len(conversation_2), 1)
|
self.assertEqual(len(conversation_2), 1)
|
||||||
|
|
||||||
outputs = conversation_agent([conversation_1, conversation_2])
|
outputs = conversation_agent([conversation_1, conversation_2], max_new_tokens=20)
|
||||||
self.assertEqual(outputs, [conversation_1, conversation_2])
|
self.assertEqual(outputs, [conversation_1, conversation_2])
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
outputs,
|
outputs,
|
||||||
@@ -118,7 +118,7 @@ class ConversationalPipelineTests(unittest.TestCase):
|
|||||||
|
|
||||||
# One conversation with history
|
# One conversation with history
|
||||||
conversation_2.add_message({"role": "user", "content": "Why do you recommend it?"})
|
conversation_2.add_message({"role": "user", "content": "Why do you recommend it?"})
|
||||||
outputs = conversation_agent(conversation_2)
|
outputs = conversation_agent(conversation_2, max_new_tokens=20)
|
||||||
self.assertEqual(outputs, conversation_2)
|
self.assertEqual(outputs, conversation_2)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
outputs,
|
outputs,
|
||||||
|
|||||||
@@ -312,8 +312,12 @@ class PipelineTesterMixin:
|
|||||||
yield copy.deepcopy(random.choice(examples))
|
yield copy.deepcopy(random.choice(examples))
|
||||||
|
|
||||||
out = []
|
out = []
|
||||||
for item in pipeline(data(10), batch_size=4):
|
if task == "conversational":
|
||||||
out.append(item)
|
for item in pipeline(data(10), batch_size=4, max_new_tokens=20):
|
||||||
|
out.append(item)
|
||||||
|
else:
|
||||||
|
for item in pipeline(data(10), batch_size=4):
|
||||||
|
out.append(item)
|
||||||
self.assertEqual(len(out), 10)
|
self.assertEqual(len(out), 10)
|
||||||
|
|
||||||
run_batch_test(pipeline, examples)
|
run_batch_test(pipeline, examples)
|
||||||
@@ -327,7 +331,6 @@ class PipelineTesterMixin:
|
|||||||
self.run_task_tests(task="automatic-speech-recognition")
|
self.run_task_tests(task="automatic-speech-recognition")
|
||||||
|
|
||||||
@is_pipeline_test
|
@is_pipeline_test
|
||||||
@unittest.skip("Conversational tests are currently broken for several models, will fix ASAP - Matt")
|
|
||||||
def test_pipeline_conversational(self):
|
def test_pipeline_conversational(self):
|
||||||
self.run_task_tests(task="conversational")
|
self.run_task_tests(task="conversational")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user