From f0a6057fbc3d46add1c15d1dfddfb098bec403c5 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 18 Sep 2023 15:08:56 +0100 Subject: [PATCH] Fix ConversationalPipeline tests (#26217) Add BlenderbotSmall templates and correct handling for conversation.past_user_inputs --- .../tokenization_blenderbot_small.py | 15 +++++++++++++++ .../tokenization_blenderbot_small_fast.py | 15 +++++++++++++++ tests/pipelines/test_pipelines_conversational.py | 10 +++++----- 3 files changed, 35 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/blenderbot_small/tokenization_blenderbot_small.py b/src/transformers/models/blenderbot_small/tokenization_blenderbot_small.py index e26cdfbd98..4acb873256 100644 --- a/src/transformers/models/blenderbot_small/tokenization_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/tokenization_blenderbot_small.py @@ -236,3 +236,18 @@ class BlenderbotSmallTokenizer(PreTrainedTokenizer): index += 1 return vocab_file, merge_file + + @property + # Copied from transformers.models.blenderbot.tokenization_blenderbot.BlenderbotTokenizer.default_chat_template + def default_chat_template(self): + """ + A very simple chat template that just adds whitespace between messages. + """ + return ( + "{% for message in messages %}" + "{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}" + "{{ message['content'] }}" + "{% if not loop.last %}{{ ' ' }}{% endif %}" + "{% endfor %}" + "{{ eos_token }}" + ) diff --git a/src/transformers/models/blenderbot_small/tokenization_blenderbot_small_fast.py b/src/transformers/models/blenderbot_small/tokenization_blenderbot_small_fast.py index adc350f3d1..8daac3e04f 100644 --- a/src/transformers/models/blenderbot_small/tokenization_blenderbot_small_fast.py +++ b/src/transformers/models/blenderbot_small/tokenization_blenderbot_small_fast.py @@ -117,3 +117,18 @@ class BlenderbotSmallTokenizerFast(PreTrainedTokenizerFast): if token_ids_1 is None: return len(cls + token_ids_0 + sep) * [0] return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + @property + # Copied from transformers.models.blenderbot.tokenization_blenderbot.BlenderbotTokenizer.default_chat_template + def default_chat_template(self): + """ + A very simple chat template that just adds whitespace between messages. + """ + return ( + "{% for message in messages %}" + "{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}" + "{{ message['content'] }}" + "{% if not loop.last %}{{ ' ' }}{% endif %}" + "{% endfor %}" + "{{ eos_token }}" + ) diff --git a/tests/pipelines/test_pipelines_conversational.py b/tests/pipelines/test_pipelines_conversational.py index dfc42ea481..2f6ba61340 100644 --- a/tests/pipelines/test_pipelines_conversational.py +++ b/tests/pipelines/test_pipelines_conversational.py @@ -140,8 +140,8 @@ class ConversationalPipelineTests(unittest.TestCase): conversation_1 = Conversation("Going to the movies tonight - any suggestions?") conversation_2 = Conversation("What's the last book you have read?") # Then - self.assertEqual(len(conversation_1.past_user_inputs), 0) - self.assertEqual(len(conversation_2.past_user_inputs), 0) + self.assertEqual(len(conversation_1.past_user_inputs), 1) + self.assertEqual(len(conversation_2.past_user_inputs), 1) # When result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000) # Then @@ -171,7 +171,7 @@ class ConversationalPipelineTests(unittest.TestCase): conversation_agent = pipeline(task="conversational", min_length_for_response=24, device=DEFAULT_DEVICE_NUM) conversation_1 = Conversation("Going to the movies tonight - any suggestions?") # Then - self.assertEqual(len(conversation_1.past_user_inputs), 0) + self.assertEqual(len(conversation_1.past_user_inputs), 1) # When result = conversation_agent(conversation_1, do_sample=False, max_length=36) # Then @@ -379,8 +379,8 @@ These are just a few of the many attractions that Paris has to offer. With so mu conversation_1 = Conversation("My name is Sarah and I live in London") conversation_2 = Conversation("Going to the movies tonight, What movie would you recommend? ") # Then - self.assertEqual(len(conversation_1.past_user_inputs), 0) - self.assertEqual(len(conversation_2.past_user_inputs), 0) + self.assertEqual(len(conversation_1.past_user_inputs), 1) + self.assertEqual(len(conversation_2.past_user_inputs), 1) # When result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000) # Then