From b1aa4982cde8224fd7f3b05fd1c48adc923cfcfe Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 8 Feb 2021 12:29:07 +0100 Subject: [PATCH] Cleaning up `ConversationalPipeline` to support more than DialoGPT. (#10002) * Cleaning up `ConversationalPipeline` to support more than DialoGPT. Currently ConversationalPipeline is heavily biased towards DialoGPT, which is the default model for this pipeline. This PR proposes changes to put back the modifications specific to DialoGPT into tokenizer-specific behavior wherever possible, by creating a `_build_conversation_input_ids` function that takes a conversation as input, and returns a list of ints corresponding to the tokens. It feels natural to put it here because all models probably have different strategies to build input_ids from the full conversation and it's the tokenizer's job to transform strings into tokens (and vice-versa). If `_build_conversation_input_ids` is missing, previous behavior is used so we don't break anything so far (except for blenderbot where it's a fix). This PR also contains a fix for too long inputs. There used to be dead code for trying to limit the size of incoming input. The introduced fix is that we limit within `_build_conversation_input_ids` to `tokenizer.model_max_length`. It corresponds to the intent of the removed dead code and is actually better because it corresponds to `model_max_length` which is different from `max_length` (which is a default parameter for `generate`). - Removed `history` logic from the Conversation as it's not relevant anymore because tokenization logic has been moved to the tokenizer. And the tokenizer cannot save any cache, and the conversation cannot know what is relevant or not. Also it's not usable from `blenderbot` because the input_ids are not append only (the EOS token is always at the end). - Added `iter_texts` method on `Conversation` because all the code was littered with some form of this iteration of past/generated_responses. * Removing torch mention in types. 
* Adding type checking to `_build_conversation_input_ids`. * Fixing import in strings. --- .../blenderbot/tokenization_blenderbot.py | 22 ++- .../models/gpt2/tokenization_gpt2.py | 13 +- .../models/gpt2/tokenization_gpt2_fast.py | 16 +- src/transformers/pipelines/conversational.py | 123 ++++-------- tests/test_pipelines_conversational.py | 178 ++++++++++-------- 5 files changed, 189 insertions(+), 163 deletions(-) diff --git a/src/transformers/models/blenderbot/tokenization_blenderbot.py b/src/transformers/models/blenderbot/tokenization_blenderbot.py index 725f31605d..ea8b435683 100644 --- a/src/transformers/models/blenderbot/tokenization_blenderbot.py +++ b/src/transformers/models/blenderbot/tokenization_blenderbot.py @@ -14,12 +14,15 @@ # limitations under the License. """Tokenization class for Blenderbot.""" -from typing import List +from typing import TYPE_CHECKING, List from ...utils import logging from ..roberta.tokenization_roberta import RobertaTokenizer +if TYPE_CHECKING: + from transformers.pipelines.conversational import Conversation + logger = logging.get_logger(__name__) @@ -74,6 +77,23 @@ class BlenderbotTokenizer(RobertaTokenizer): """ return token_ids_0 + [self.eos_token_id] + def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]: + inputs = [] + for is_user, text in conversation.iter_texts(): + if is_user: + # We need to space prefix as it's being done within blenderbot + inputs.append(" " + text) + else: + # Generated responses should contain them already. 
+ inputs.append(text) + + full_string = " ".join(inputs) + input_ids = self.encode(full_string) + if len(input_ids) > self.model_max_length: + input_ids = input_ids[-self.model_max_length :] + logger.warning(f"Trimmed input from conversation as it was longer than {self.model_max_length} tokens.") + return input_ids + def get_pairs(word): """ diff --git a/src/transformers/models/gpt2/tokenization_gpt2.py b/src/transformers/models/gpt2/tokenization_gpt2.py index 87f353b93c..4601f902e0 100644 --- a/src/transformers/models/gpt2/tokenization_gpt2.py +++ b/src/transformers/models/gpt2/tokenization_gpt2.py @@ -18,7 +18,7 @@ import json import os from functools import lru_cache -from typing import Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple import regex as re @@ -26,6 +26,9 @@ from ...tokenization_utils import AddedToken, PreTrainedTokenizer from ...utils import logging +if TYPE_CHECKING: + from transformers.pipelines.conversational import Conversation + logger = logging.get_logger(__name__) VOCAB_FILES_NAMES = { @@ -296,3 +299,11 @@ class GPT2Tokenizer(PreTrainedTokenizer): if is_split_into_words or add_prefix_space: text = " " + text return (text, kwargs) + + def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]: + input_ids = [] + for is_user, text in conversation.iter_texts(): + input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id]) + if len(input_ids) > self.model_max_length: + input_ids = input_ids[-self.model_max_length :] + return input_ids diff --git a/src/transformers/models/gpt2/tokenization_gpt2_fast.py b/src/transformers/models/gpt2/tokenization_gpt2_fast.py index 1a9b643a6c..54356a52ec 100644 --- a/src/transformers/models/gpt2/tokenization_gpt2_fast.py +++ b/src/transformers/models/gpt2/tokenization_gpt2_fast.py @@ -16,7 +16,7 @@ import json -from typing import Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple from tokenizers import pre_tokenizers @@ 
-26,6 +26,10 @@ from ...utils import logging from .tokenization_gpt2 import GPT2Tokenizer +if TYPE_CHECKING: + from transformers.pipelines.conversational import Conversation + + logger = logging.get_logger(__name__) VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"} @@ -171,3 +175,13 @@ class GPT2TokenizerFast(PreTrainedTokenizerFast): def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: files = self._tokenizer.model.save(save_directory, name=filename_prefix) return tuple(files) + + def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]: + """This corresponds to DialoGPT variants of models.""" + input_ids = [] + for is_user, text in conversation.iter_texts(): + input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id]) + + if len(input_ids) > self.model_max_length: + input_ids = input_ids[-self.model_max_length :] + return input_ids diff --git a/src/transformers/pipelines/conversational.py b/src/transformers/pipelines/conversational.py index 7e22b8b92b..0ab07eded7 100644 --- a/src/transformers/pipelines/conversational.py +++ b/src/transformers/pipelines/conversational.py @@ -1,8 +1,7 @@ import uuid -from typing import List, Optional, Union +from typing import Any, Dict, List, Optional, Union from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available -from ..tokenization_utils import TruncationStrategy from ..utils import logging from .base import PIPELINE_INIT_ARGS, Pipeline @@ -70,8 +69,6 @@ class Conversation: self.past_user_inputs: List[str] = past_user_inputs self.generated_responses: List[str] = generated_responses self.new_user_input: Optional[str] = text - self._index: int = 0 - self._history: List[int] = [] def __eq__(self, other): if not isinstance(other, Conversation): @@ -128,6 +125,19 @@ class Conversation: """ self.generated_responses.append(response) + def 
iter_texts(self): + """ + Iterates over all blobs of the conversation. + + Retuns: Iterator of (is_user, text_chunk) in chronological order of the conversation. ``is_user`` is a + :obj:`bool`, ``text_chunks`` is a :obj:`str`. + """ + for user_input, generated_response in zip(self.past_user_inputs, self.generated_responses): + yield True, user_input + yield False, generated_response + if self.new_user_input: + yield True, self.new_user_input + def __repr__(self): """ Generates a string representation of the conversation. @@ -139,11 +149,9 @@ class Conversation: suggestions? bot >> The Big Lebowski """ output = "Conversation id: {} \n".format(self.uuid) - for user_input, generated_response in zip(self.past_user_inputs, self.generated_responses): - output += "user >> {} \n".format(user_input) - output += "bot >> {} \n".format(generated_response) - if self.new_user_input is not None: - output += "user >> {} \n".format(self.new_user_input) + for is_user, text in self.iter_texts(): + name = "user" if is_user else "bot" + output += "{} >> {} \n".format(name, text) return output @@ -191,34 +199,6 @@ class ConversationalPipeline(Pipeline): self.min_length_for_response = min_length_for_response - def _get_history(self, conversation): - """ - Private function (subject to change) that simply tokenizes and concatenates past inputs. Also saves that - tokenization into the conversation state. - - Args: - conversation (:class:`~transformers.Conversation`) - - Returns: - :obj:`List[int]`: The list of tokens for the past input of that conversation. 
- """ - # Make a copy to prevent messing cache up if there's an error - # within this function - history = conversation._history.copy() - index = conversation._index - new_index = index - for i, (past_user_input, generated_response) in enumerate( - zip(conversation.past_user_inputs[index:], conversation.generated_responses[index:]) - ): - for el in (past_user_input, generated_response): - new_history = self._parse_and_tokenize([el])[0] - history.extend(new_history) - new_index = i + index + 1 - conversation._index = new_index - conversation._history = history - # Hand back a copy to caller so they can't accidently modify our cache. - return history.copy() - def __call__( self, conversations: Union[Conversation, List[Conversation]], @@ -249,7 +229,7 @@ class ConversationalPipeline(Pipeline): for conversation in conversations: assert isinstance( conversation, Conversation - ), "DialoguePipeline expects a Conversation or list of Conversations as an input" + ), "ConversationalPipeline expects a Conversation or list of Conversations as an input" if conversation.new_user_input is None: raise ValueError( "Conversation with UUID {} does not contain new user input to process. 
" @@ -261,14 +241,11 @@ class ConversationalPipeline(Pipeline): self.tokenizer.pad_token_id is not None or self.tokenizer.eos_token_id is not None ), "Please make sure that the tokenizer has a pad_token_id or eos_token_id when using a batch input" else: - raise ValueError("DialoguePipeline expects a Conversation or list of Conversations as an input") + raise ValueError("ConversationalPipeline expects a Conversation or list of Conversations as an input") with self.device_placement(): - inputs = self._parse_and_tokenize([conversation.new_user_input for conversation in conversations]) - histories = [self._get_history(conversation) for conversation in conversations] - max_length = generate_kwargs.get("max_length", self.model.config.max_length) - inputs = self._concat_inputs_history(inputs, histories, max_length) + inputs = self._parse_and_tokenize(conversations) if self.framework == "pt": inputs = self.ensure_tensor_on_device(**inputs) @@ -277,11 +254,6 @@ class ConversationalPipeline(Pipeline): elif self.framework == "tf": input_length = tf.shape(inputs["input_ids"])[-1].numpy() - if input_length > 0.9 * max_length: - logger.warning( - "Longest conversation length: {} is bigger than 0.9 * max_length: {}. 
" - "You might consider trimming the early phase of the conversation".format(input_length, max_length) - ) generated_responses = self.model.generate( inputs["input_ids"], attention_mask=inputs["attention_mask"], @@ -318,18 +290,6 @@ class ConversationalPipeline(Pipeline): else: return output - def _parse_and_tokenize( - self, inputs, add_special_tokens=False, padding=False, truncation=TruncationStrategy.DO_NOT_TRUNCATE, **kwargs - ): - """ - Parse arguments and tokenize, adding an EOS token at the end of the user input - """ - # Parse arguments - inputs = self.tokenizer(inputs, add_special_tokens=add_special_tokens, padding=padding).get("input_ids", []) - for input in inputs: - input.append(self.tokenizer.eos_token_id) - return inputs - def _clean_padding_history(self, generated_tensor) -> List[List[int]]: """ Cleans the padding history. Padding may be generated in two places when multiple conversations are provided as @@ -363,28 +323,23 @@ class ConversationalPipeline(Pipeline): outputs.append(sequence_tokens) return outputs - def _concat_inputs_history(self, inputs: List[List[int]], histories: List[Optional[List[int]]], max_length: int): - """ - Builds an input prepended by the history for this conversation, allowing multi-turn conversation with context - """ - outputs = [] - for new_input, history in zip(inputs, histories): - if history is not None: - new_input = history + new_input - if len(new_input) > max_length - self.min_length_for_response: - cutoff_eos_index = 0 - while len(new_input) - cutoff_eos_index > max_length - self.min_length_for_response: - if cutoff_eos_index >= len(new_input): - break - cutoff_eos_index = new_input[cutoff_eos_index:].index(self.tokenizer.eos_token_id) - if cutoff_eos_index == 0 or cutoff_eos_index == len(new_input) - 1: - break - else: - logger.warning( - f"Cutting history off because it's too long ({len(new_input)} > {max_length - self.min_length_for_response}) for underlying model" - ) - outputs.append(new_input) - 
padded_outputs = self.tokenizer.pad( - {"input_ids": outputs}, padding="longest", return_attention_mask=True, return_tensors=self.framework + def _legacy_parse_and_tokenize(self, conversation: List[Conversation]) -> List[int]: + eos_token_id = self.tokenizer.eos_token_id + input_ids = [] + for is_user, text in conversation.iter_texts(): + input_ids.extend(self.tokenizer.encode(text, add_special_tokens=False) + [eos_token_id]) + + if len(input_ids) > self.tokenizer.model_max_length: + input_ids = input_ids[-self.model_max_length :] + return input_ids + + def _parse_and_tokenize(self, conversations: List[Conversation]) -> Dict[str, Any]: + if hasattr(self.tokenizer, "_build_conversation_input_ids"): + input_ids = [self.tokenizer._build_conversation_input_ids(conversation) for conversation in conversations] + else: + # If the tokenizer cannot handle conversations, we default to only the old version + input_ids = [self._legacy_parse_and_tokenize(conversation) for conversation in conversations] + inputs = self.tokenizer.pad( + {"input_ids": input_ids}, padding="longest", return_attention_mask=True, return_tensors="pt" ) - return padded_outputs + return inputs diff --git a/tests/test_pipelines_conversational.py b/tests/test_pipelines_conversational.py index 276c801d64..4ea4d808a1 100644 --- a/tests/test_pipelines_conversational.py +++ b/tests/test_pipelines_conversational.py @@ -15,6 +15,7 @@ import unittest from transformers import ( + AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, Conversation, @@ -87,11 +88,7 @@ class SimpleConversationPipelineTests(unittest.TestCase): self.assertEqual(len(conversation_1.past_user_inputs), 0) self.assertEqual(len(conversation_2.past_user_inputs), 0) - with self.assertLogs("transformers", level="WARNING") as log: - result = conversation_agent([conversation_1, conversation_2], max_length=48) - self.assertEqual(len(log.output), 2) - self.assertIn("You might consider trimming the early phase of the conversation", 
log.output[0]) - self.assertIn("Setting `pad_token_id`", log.output[1]) + result = conversation_agent([conversation_1, conversation_2], max_length=48) # Two conversations in one pass self.assertEqual(result, [conversation_1, conversation_2]) @@ -111,12 +108,7 @@ class SimpleConversationPipelineTests(unittest.TestCase): # One conversation with history conversation_2.add_user_input("Why do you recommend it?") - with self.assertLogs("transformers", level="WARNING") as log: - result = conversation_agent(conversation_2, max_length=64) - self.assertEqual(len(log.output), 3) - self.assertIn("Cutting history off because it's too long", log.output[0]) - self.assertIn("You might consider trimming the early phase of the conversation", log.output[1]) - self.assertIn("Setting `pad_token_id`", log.output[2]) + result = conversation_agent(conversation_2, max_length=64) self.assertEqual(result, conversation_2) self.assertEqual( @@ -128,65 +120,6 @@ class SimpleConversationPipelineTests(unittest.TestCase): ), ) - @require_torch - def test_history_cache(self): - conversation_agent = self.get_pipeline() - conversation = Conversation( - "Why do you recommend it?", - past_user_inputs=["What's the last book you have read?"], - generated_responses=["b"], - ) - with self.assertLogs("transformers", level="WARNING") as log: - _ = conversation_agent(conversation, max_length=64) - self.assertEqual(len(log.output), 3) - self.assertIn("Cutting history off because it's too long (63 > 32) for underlying model", log.output[0]) - self.assertIn("63 is bigger than 0.9 * max_length: 64", log.output[1]) - self.assertIn("Setting `pad_token_id`", log.output[2]) - self.assertEqual(conversation._index, 1) - self.assertEqual( - conversation._history, - [ - 87, - 104, - 97, - 116, - 39, - 115, - 32, - 116, - 104, - 101, - 32, - 108, - 97, - 115, - 116, - 32, - 98, - 111, - 111, - 107, - 32, - 121, - 111, - 117, - 32, - 104, - 97, - 118, - 101, - 32, - 114, - 101, - 97, - 100, - 63, - 259, # EOS - 98, # b - 
259, # EOS - ], - ) - class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase): pipeline_task = "conversational" @@ -276,6 +209,102 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas self.assertEqual(result.past_user_inputs[1], "Is it an action movie?") self.assertEqual(result.generated_responses[1], "It's a comedy.") + @require_torch + @slow + def test_integration_torch_conversation_dialogpt_input_ids(self): + tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small") + model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small") + nlp = ConversationalPipeline(model=model, tokenizer=tokenizer) + + conversation_1 = Conversation("hello") + inputs = nlp._parse_and_tokenize([conversation_1]) + self.assertEqual(inputs["input_ids"].tolist(), [[31373, 50256]]) + + conversation_2 = Conversation("how are you ?", past_user_inputs=["hello"], generated_responses=["Hi there!"]) + inputs = nlp._parse_and_tokenize([conversation_2]) + self.assertEqual( + inputs["input_ids"].tolist(), [[31373, 50256, 17250, 612, 0, 50256, 4919, 389, 345, 5633, 50256]] + ) + + inputs = nlp._parse_and_tokenize([conversation_1, conversation_2]) + self.assertEqual( + inputs["input_ids"].tolist(), + [ + [31373, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256], + [31373, 50256, 17250, 612, 0, 50256, 4919, 389, 345, 5633, 50256], + ], + ) + + @require_torch + @slow + def test_integration_torch_conversation_blenderbot_400M_input_ids(self): + tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill") + model = AutoModelForSeq2SeqLM.from_pretrained("facebook/blenderbot-400M-distill") + nlp = ConversationalPipeline(model=model, tokenizer=tokenizer) + + # test1 + conversation_1 = Conversation("hello") + inputs = nlp._parse_and_tokenize([conversation_1]) + self.assertEqual(inputs["input_ids"].tolist(), [[1710, 86, 2]]) + + # test2 + conversation_1 = Conversation( + "I like lasagne.", + 
past_user_inputs=["hello"], + generated_responses=[ + " Do you like lasagne? It is a traditional Italian dish consisting of a shepherd's pie." + ], + ) + inputs = nlp._parse_and_tokenize([conversation_1]) + self.assertEqual( + inputs["input_ids"].tolist(), + [ + # This should be compared with the same conversation on ParlAI `safe_interactive` demo. + [ + 1710, # hello + 86, + 228, # Double space + 228, + 946, + 304, + 398, + 6881, + 558, + 964, + 38, + 452, + 315, + 265, + 6252, + 452, + 322, + 968, + 6884, + 3146, + 278, + 306, + 265, + 617, + 87, + 388, + 75, + 341, + 286, + 521, + 21, + 228, # Double space + 228, + 281, # I like lasagne. + 398, + 6881, + 558, + 964, + 21, + 2, # EOS + ] + ], + ) + @require_torch @slow def test_integration_torch_conversation_blenderbot_400M(self): @@ -295,11 +324,11 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas " Hello! How are you doing today? I just got back from a walk with my dog.", ) - conversation_1 = Conversation(" Lasagne hello") + conversation_1 = Conversation("Lasagne hello") result = nlp(conversation_1, encoder_no_repeat_ngram_size=3) self.assertEqual( result.generated_responses[0], - " Lasagne is my favorite Italian dish. Do you like lasagne?", + " Do you like lasagne? It is a traditional Italian dish consisting of a shepherd's pie.", ) conversation_1 = Conversation( @@ -311,10 +340,7 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas ) self.assertEqual( result.generated_responses[0], - # ParlAI implementation output, we have a different one, but it's our - # second best, you can check by using num_return_sequences=10 - # " Hello! How are you? I'm just getting ready to go to work, how about you?", - " Lasagne is a traditional Italian dish consisting of a yeasted flatbread typically topped with tomato sauce and cheese.", + " Me too. I like how it can be topped with vegetables, meats, and condiments.", ) @require_torch