Backward compatibility fix for the Conversation class (#27176)
* Backward compatibility fix for the Conversation class * Explain what's going on in the conditional
This commit is contained in:
@@ -54,6 +54,7 @@ class Conversation:
|
|||||||
|
|
||||||
# This block deals with the legacy args - new code should just totally
|
# This block deals with the legacy args - new code should just totally
|
||||||
# avoid past_user_inputs and generated_responses
|
# avoid past_user_inputs and generated_responses
|
||||||
|
self._num_processed_user_inputs = 0
|
||||||
generated_responses = deprecated_kwargs.pop("generated_responses", None)
|
generated_responses = deprecated_kwargs.pop("generated_responses", None)
|
||||||
past_user_inputs = deprecated_kwargs.pop("past_user_inputs", None)
|
past_user_inputs = deprecated_kwargs.pop("past_user_inputs", None)
|
||||||
if generated_responses is not None and past_user_inputs is None:
|
if generated_responses is not None and past_user_inputs is None:
|
||||||
@@ -114,10 +115,11 @@ class Conversation:
|
|||||||
|
|
||||||
def mark_processed(self):
|
def mark_processed(self):
|
||||||
"""
|
"""
|
||||||
This is a legacy method that no longer has any effect, as the Conversation no longer distinguishes between
|
This is a legacy method, as the Conversation no longer distinguishes between processed and unprocessed user
|
||||||
processed and unprocessed user input.
|
input. We set a counter here to keep behaviour mostly backward-compatible, but in general you should just read
|
||||||
|
the messages directly when writing new code.
|
||||||
"""
|
"""
|
||||||
pass
|
self._num_processed_user_inputs = len(self._user_messages)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
for message in self.messages:
|
for message in self.messages:
|
||||||
@@ -163,7 +165,17 @@ class Conversation:
|
|||||||
@property
|
@property
|
||||||
def past_user_inputs(self):
|
def past_user_inputs(self):
|
||||||
# This is a legacy property for backwards compatibility. It is recommended to just directly access
|
# This is a legacy property for backwards compatibility. It is recommended to just directly access
|
||||||
# conversation.messages instead.
|
# conversation.messages instead. The modern class does not care about which messages are "processed"
|
||||||
|
# or not.
|
||||||
|
if not self._user_messages:
|
||||||
|
return []
|
||||||
|
# In the past, the most recent user message had to be mark_processed() before being included
|
||||||
|
# in past_user_messages. The class essentially had a single-message buffer, representing messages that
|
||||||
|
# had not yet been replied to. This is no longer the case, but we mimic the behaviour in this property
|
||||||
|
# for backward compatibility.
|
||||||
|
if self.messages[-1]["role"] != "user" or self._num_processed_user_inputs == len(self._user_messages):
|
||||||
|
return self._user_messages
|
||||||
|
|
||||||
return self._user_messages[:-1]
|
return self._user_messages[:-1]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -136,8 +136,8 @@ class ConversationalPipelineTests(unittest.TestCase):
|
|||||||
conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
|
conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
|
||||||
conversation_2 = Conversation("What's the last book you have read?")
|
conversation_2 = Conversation("What's the last book you have read?")
|
||||||
# Then
|
# Then
|
||||||
self.assertEqual(len(conversation_1.past_user_inputs), 1)
|
self.assertEqual(len(conversation_1.past_user_inputs), 0)
|
||||||
self.assertEqual(len(conversation_2.past_user_inputs), 1)
|
self.assertEqual(len(conversation_2.past_user_inputs), 0)
|
||||||
# When
|
# When
|
||||||
result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)
|
result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)
|
||||||
# Then
|
# Then
|
||||||
@@ -167,7 +167,7 @@ class ConversationalPipelineTests(unittest.TestCase):
|
|||||||
conversation_agent = pipeline(task="conversational", min_length_for_response=24, device=torch_device)
|
conversation_agent = pipeline(task="conversational", min_length_for_response=24, device=torch_device)
|
||||||
conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
|
conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
|
||||||
# Then
|
# Then
|
||||||
self.assertEqual(len(conversation_1.past_user_inputs), 1)
|
self.assertEqual(len(conversation_1.past_user_inputs), 0)
|
||||||
# When
|
# When
|
||||||
result = conversation_agent(conversation_1, do_sample=False, max_length=36)
|
result = conversation_agent(conversation_1, do_sample=False, max_length=36)
|
||||||
# Then
|
# Then
|
||||||
@@ -375,8 +375,8 @@ These are just a few of the many attractions that Paris has to offer. With so mu
|
|||||||
conversation_1 = Conversation("My name is Sarah and I live in London")
|
conversation_1 = Conversation("My name is Sarah and I live in London")
|
||||||
conversation_2 = Conversation("Going to the movies tonight, What movie would you recommend? ")
|
conversation_2 = Conversation("Going to the movies tonight, What movie would you recommend? ")
|
||||||
# Then
|
# Then
|
||||||
self.assertEqual(len(conversation_1.past_user_inputs), 1)
|
self.assertEqual(len(conversation_1.past_user_inputs), 0)
|
||||||
self.assertEqual(len(conversation_2.past_user_inputs), 1)
|
self.assertEqual(len(conversation_2.past_user_inputs), 0)
|
||||||
# When
|
# When
|
||||||
result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)
|
result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)
|
||||||
# Then
|
# Then
|
||||||
|
|||||||
Reference in New Issue
Block a user