Fix ConversationalPipeline tests (#26217)
Add BlenderbotSmall templates and correct handling for conversation.past_user_inputs
This commit is contained in:
@@ -236,3 +236,18 @@ class BlenderbotSmallTokenizer(PreTrainedTokenizer):
|
|||||||
index += 1
|
index += 1
|
||||||
|
|
||||||
return vocab_file, merge_file
|
return vocab_file, merge_file
|
||||||
|
|
||||||
|
@property
|
||||||
|
# Copied from transformers.models.blenderbot.tokenization_blenderbot.BlenderbotTokenizer.default_chat_template
|
||||||
|
def default_chat_template(self):
|
||||||
|
"""
|
||||||
|
A very simple chat template that just adds whitespace between messages.
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
"{% for message in messages %}"
|
||||||
|
"{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}"
|
||||||
|
"{{ message['content'] }}"
|
||||||
|
"{% if not loop.last %}{{ ' ' }}{% endif %}"
|
||||||
|
"{% endfor %}"
|
||||||
|
"{{ eos_token }}"
|
||||||
|
)
|
||||||
|
|||||||
@@ -117,3 +117,18 @@ class BlenderbotSmallTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
if token_ids_1 is None:
|
if token_ids_1 is None:
|
||||||
return len(cls + token_ids_0 + sep) * [0]
|
return len(cls + token_ids_0 + sep) * [0]
|
||||||
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
|
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
# Copied from transformers.models.blenderbot.tokenization_blenderbot.BlenderbotTokenizer.default_chat_template
|
||||||
|
def default_chat_template(self):
|
||||||
|
"""
|
||||||
|
A very simple chat template that just adds whitespace between messages.
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
"{% for message in messages %}"
|
||||||
|
"{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}"
|
||||||
|
"{{ message['content'] }}"
|
||||||
|
"{% if not loop.last %}{{ ' ' }}{% endif %}"
|
||||||
|
"{% endfor %}"
|
||||||
|
"{{ eos_token }}"
|
||||||
|
)
|
||||||
|
|||||||
@@ -140,8 +140,8 @@ class ConversationalPipelineTests(unittest.TestCase):
|
|||||||
conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
|
conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
|
||||||
conversation_2 = Conversation("What's the last book you have read?")
|
conversation_2 = Conversation("What's the last book you have read?")
|
||||||
# Then
|
# Then
|
||||||
self.assertEqual(len(conversation_1.past_user_inputs), 0)
|
self.assertEqual(len(conversation_1.past_user_inputs), 1)
|
||||||
self.assertEqual(len(conversation_2.past_user_inputs), 0)
|
self.assertEqual(len(conversation_2.past_user_inputs), 1)
|
||||||
# When
|
# When
|
||||||
result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)
|
result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)
|
||||||
# Then
|
# Then
|
||||||
@@ -171,7 +171,7 @@ class ConversationalPipelineTests(unittest.TestCase):
|
|||||||
conversation_agent = pipeline(task="conversational", min_length_for_response=24, device=DEFAULT_DEVICE_NUM)
|
conversation_agent = pipeline(task="conversational", min_length_for_response=24, device=DEFAULT_DEVICE_NUM)
|
||||||
conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
|
conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
|
||||||
# Then
|
# Then
|
||||||
self.assertEqual(len(conversation_1.past_user_inputs), 0)
|
self.assertEqual(len(conversation_1.past_user_inputs), 1)
|
||||||
# When
|
# When
|
||||||
result = conversation_agent(conversation_1, do_sample=False, max_length=36)
|
result = conversation_agent(conversation_1, do_sample=False, max_length=36)
|
||||||
# Then
|
# Then
|
||||||
@@ -379,8 +379,8 @@ These are just a few of the many attractions that Paris has to offer. With so mu
|
|||||||
conversation_1 = Conversation("My name is Sarah and I live in London")
|
conversation_1 = Conversation("My name is Sarah and I live in London")
|
||||||
conversation_2 = Conversation("Going to the movies tonight, What movie would you recommend? ")
|
conversation_2 = Conversation("Going to the movies tonight, What movie would you recommend? ")
|
||||||
# Then
|
# Then
|
||||||
self.assertEqual(len(conversation_1.past_user_inputs), 0)
|
self.assertEqual(len(conversation_1.past_user_inputs), 1)
|
||||||
self.assertEqual(len(conversation_2.past_user_inputs), 0)
|
self.assertEqual(len(conversation_2.past_user_inputs), 1)
|
||||||
# When
|
# When
|
||||||
result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)
|
result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)
|
||||||
# Then
|
# Then
|
||||||
|
|||||||
Reference in New Issue
Block a user