Backward compatibility fix for the Conversation class (#27176)

* Backward compatibility fix for the Conversation class

* Explain what's going on in the conditional
This commit is contained in:
Matt
2023-10-31 15:12:06 +00:00
committed by GitHub
parent 309a90664f
commit 05f2290114
2 changed files with 21 additions and 9 deletions

View File

@@ -136,8 +136,8 @@ class ConversationalPipelineTests(unittest.TestCase):
conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
conversation_2 = Conversation("What's the last book you have read?")
# Then
self.assertEqual(len(conversation_1.past_user_inputs), 1)
self.assertEqual(len(conversation_2.past_user_inputs), 1)
self.assertEqual(len(conversation_1.past_user_inputs), 0)
self.assertEqual(len(conversation_2.past_user_inputs), 0)
# When
result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)
# Then
@@ -167,7 +167,7 @@ class ConversationalPipelineTests(unittest.TestCase):
conversation_agent = pipeline(task="conversational", min_length_for_response=24, device=torch_device)
conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
# Then
self.assertEqual(len(conversation_1.past_user_inputs), 1)
self.assertEqual(len(conversation_1.past_user_inputs), 0)
# When
result = conversation_agent(conversation_1, do_sample=False, max_length=36)
# Then
@@ -375,8 +375,8 @@ These are just a few of the many attractions that Paris has to offer. With so mu
conversation_1 = Conversation("My name is Sarah and I live in London")
conversation_2 = Conversation("Going to the movies tonight, What movie would you recommend? ")
# Then
self.assertEqual(len(conversation_1.past_user_inputs), 1)
self.assertEqual(len(conversation_2.past_user_inputs), 1)
self.assertEqual(len(conversation_1.past_user_inputs), 0)
self.assertEqual(len(conversation_2.past_user_inputs), 0)
# When
result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)
# Then