Cleaning up ConversationalPipeline to support more than DialoGPT. (#10002)
* Cleaning up `ConversationalPipeline` to support more than DialoGPT. Currently ConversationalPipeline was heavily biased towards DialoGPT ,which is the default model for this pipeline. This PR proposes changes to put back the modifications specific to DialoGPT into tokenizer-specific behavior wherever possible, by creating `_build_conversation_input_ids` function that takes conversation as input, and returns a list of ints corresponding to the tokens. It feels natural to put here because all models have probably different strategies to build input_ids from the full conversation and it's the tokenizer's job to transform strings into tokens (and vice-versa) If `_build_conversation_input_ids` is missing, previous behavior is used so we don't break anything so far (except for blenderbot where it's a fix). This PR also contains a fix for too long inputs. There used to be dead code for trying to limit the size of incoming input. The introduced fixed is that we limit within `_build_conversation_input_ids` to `tokenizer.model_max_length`. It corresponds to the intent of the removed dead code and is actually better because it corresponds to `model_max_length` which is different from `max_length` (which is a default parameter for `generate`). - Removed `history` logic from the Conversation as it's not relevant anymore because tokenization logic has been moved to tokenizer. And tokenizer cannot save any cache, and conversation cannot know what is relevant or not. Also it's not usable from `blenderbot` because the input_ids are not append only (EOS tokens is always at the end). - Added `iter_texts` method on `Conversation` because all the code was literred with some form of this iteration of past/generated_responses. * Removing torch mention in types. * Adding type checking to `_build_conversation_input_ids`. * Fixing import in strings.
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
import unittest
|
||||
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
Conversation,
|
||||
@@ -87,11 +88,7 @@ class SimpleConversationPipelineTests(unittest.TestCase):
|
||||
self.assertEqual(len(conversation_1.past_user_inputs), 0)
|
||||
self.assertEqual(len(conversation_2.past_user_inputs), 0)
|
||||
|
||||
with self.assertLogs("transformers", level="WARNING") as log:
|
||||
result = conversation_agent([conversation_1, conversation_2], max_length=48)
|
||||
self.assertEqual(len(log.output), 2)
|
||||
self.assertIn("You might consider trimming the early phase of the conversation", log.output[0])
|
||||
self.assertIn("Setting `pad_token_id`", log.output[1])
|
||||
result = conversation_agent([conversation_1, conversation_2], max_length=48)
|
||||
|
||||
# Two conversations in one pass
|
||||
self.assertEqual(result, [conversation_1, conversation_2])
|
||||
@@ -111,12 +108,7 @@ class SimpleConversationPipelineTests(unittest.TestCase):
|
||||
|
||||
# One conversation with history
|
||||
conversation_2.add_user_input("Why do you recommend it?")
|
||||
with self.assertLogs("transformers", level="WARNING") as log:
|
||||
result = conversation_agent(conversation_2, max_length=64)
|
||||
self.assertEqual(len(log.output), 3)
|
||||
self.assertIn("Cutting history off because it's too long", log.output[0])
|
||||
self.assertIn("You might consider trimming the early phase of the conversation", log.output[1])
|
||||
self.assertIn("Setting `pad_token_id`", log.output[2])
|
||||
result = conversation_agent(conversation_2, max_length=64)
|
||||
|
||||
self.assertEqual(result, conversation_2)
|
||||
self.assertEqual(
|
||||
@@ -128,65 +120,6 @@ class SimpleConversationPipelineTests(unittest.TestCase):
|
||||
),
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_history_cache(self):
|
||||
conversation_agent = self.get_pipeline()
|
||||
conversation = Conversation(
|
||||
"Why do you recommend it?",
|
||||
past_user_inputs=["What's the last book you have read?"],
|
||||
generated_responses=["b"],
|
||||
)
|
||||
with self.assertLogs("transformers", level="WARNING") as log:
|
||||
_ = conversation_agent(conversation, max_length=64)
|
||||
self.assertEqual(len(log.output), 3)
|
||||
self.assertIn("Cutting history off because it's too long (63 > 32) for underlying model", log.output[0])
|
||||
self.assertIn("63 is bigger than 0.9 * max_length: 64", log.output[1])
|
||||
self.assertIn("Setting `pad_token_id`", log.output[2])
|
||||
self.assertEqual(conversation._index, 1)
|
||||
self.assertEqual(
|
||||
conversation._history,
|
||||
[
|
||||
87,
|
||||
104,
|
||||
97,
|
||||
116,
|
||||
39,
|
||||
115,
|
||||
32,
|
||||
116,
|
||||
104,
|
||||
101,
|
||||
32,
|
||||
108,
|
||||
97,
|
||||
115,
|
||||
116,
|
||||
32,
|
||||
98,
|
||||
111,
|
||||
111,
|
||||
107,
|
||||
32,
|
||||
121,
|
||||
111,
|
||||
117,
|
||||
32,
|
||||
104,
|
||||
97,
|
||||
118,
|
||||
101,
|
||||
32,
|
||||
114,
|
||||
101,
|
||||
97,
|
||||
100,
|
||||
63,
|
||||
259, # EOS
|
||||
98, # b
|
||||
259, # EOS
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||
pipeline_task = "conversational"
|
||||
@@ -276,6 +209,102 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
|
||||
self.assertEqual(result.past_user_inputs[1], "Is it an action movie?")
|
||||
self.assertEqual(result.generated_responses[1], "It's a comedy.")
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_integration_torch_conversation_dialogpt_input_ids(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
|
||||
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
|
||||
nlp = ConversationalPipeline(model=model, tokenizer=tokenizer)
|
||||
|
||||
conversation_1 = Conversation("hello")
|
||||
inputs = nlp._parse_and_tokenize([conversation_1])
|
||||
self.assertEqual(inputs["input_ids"].tolist(), [[31373, 50256]])
|
||||
|
||||
conversation_2 = Conversation("how are you ?", past_user_inputs=["hello"], generated_responses=["Hi there!"])
|
||||
inputs = nlp._parse_and_tokenize([conversation_2])
|
||||
self.assertEqual(
|
||||
inputs["input_ids"].tolist(), [[31373, 50256, 17250, 612, 0, 50256, 4919, 389, 345, 5633, 50256]]
|
||||
)
|
||||
|
||||
inputs = nlp._parse_and_tokenize([conversation_1, conversation_2])
|
||||
self.assertEqual(
|
||||
inputs["input_ids"].tolist(),
|
||||
[
|
||||
[31373, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256],
|
||||
[31373, 50256, 17250, 612, 0, 50256, 4919, 389, 345, 5633, 50256],
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_integration_torch_conversation_blenderbot_400M_input_ids(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/blenderbot-400M-distill")
|
||||
nlp = ConversationalPipeline(model=model, tokenizer=tokenizer)
|
||||
|
||||
# test1
|
||||
conversation_1 = Conversation("hello")
|
||||
inputs = nlp._parse_and_tokenize([conversation_1])
|
||||
self.assertEqual(inputs["input_ids"].tolist(), [[1710, 86, 2]])
|
||||
|
||||
# test2
|
||||
conversation_1 = Conversation(
|
||||
"I like lasagne.",
|
||||
past_user_inputs=["hello"],
|
||||
generated_responses=[
|
||||
" Do you like lasagne? It is a traditional Italian dish consisting of a shepherd's pie."
|
||||
],
|
||||
)
|
||||
inputs = nlp._parse_and_tokenize([conversation_1])
|
||||
self.assertEqual(
|
||||
inputs["input_ids"].tolist(),
|
||||
[
|
||||
# This should be compared with the same conversation on ParlAI `safe_interactive` demo.
|
||||
[
|
||||
1710, # hello
|
||||
86,
|
||||
228, # Double space
|
||||
228,
|
||||
946,
|
||||
304,
|
||||
398,
|
||||
6881,
|
||||
558,
|
||||
964,
|
||||
38,
|
||||
452,
|
||||
315,
|
||||
265,
|
||||
6252,
|
||||
452,
|
||||
322,
|
||||
968,
|
||||
6884,
|
||||
3146,
|
||||
278,
|
||||
306,
|
||||
265,
|
||||
617,
|
||||
87,
|
||||
388,
|
||||
75,
|
||||
341,
|
||||
286,
|
||||
521,
|
||||
21,
|
||||
228, # Double space
|
||||
228,
|
||||
281, # I like lasagne.
|
||||
398,
|
||||
6881,
|
||||
558,
|
||||
964,
|
||||
21,
|
||||
2, # EOS
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_integration_torch_conversation_blenderbot_400M(self):
|
||||
@@ -295,11 +324,11 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
|
||||
" Hello! How are you doing today? I just got back from a walk with my dog.",
|
||||
)
|
||||
|
||||
conversation_1 = Conversation(" Lasagne hello")
|
||||
conversation_1 = Conversation("Lasagne hello")
|
||||
result = nlp(conversation_1, encoder_no_repeat_ngram_size=3)
|
||||
self.assertEqual(
|
||||
result.generated_responses[0],
|
||||
" Lasagne is my favorite Italian dish. Do you like lasagne?",
|
||||
" Do you like lasagne? It is a traditional Italian dish consisting of a shepherd's pie.",
|
||||
)
|
||||
|
||||
conversation_1 = Conversation(
|
||||
@@ -311,10 +340,7 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
|
||||
)
|
||||
self.assertEqual(
|
||||
result.generated_responses[0],
|
||||
# ParlAI implementation output, we have a different one, but it's our
|
||||
# second best, you can check by using num_return_sequences=10
|
||||
# " Hello! How are you? I'm just getting ready to go to work, how about you?",
|
||||
" Lasagne is a traditional Italian dish consisting of a yeasted flatbread typically topped with tomato sauce and cheese.",
|
||||
" Me too. I like how it can be topped with vegetables, meats, and condiments.",
|
||||
)
|
||||
|
||||
@require_torch
|
||||
|
||||
Reference in New Issue
Block a user