diff --git a/docs/source/main_classes/pipelines.rst b/docs/source/main_classes/pipelines.rst index ea51feb7ca..214858fb5a 100644 --- a/docs/source/main_classes/pipelines.rst +++ b/docs/source/main_classes/pipelines.rst @@ -71,3 +71,11 @@ TextGenerationPipeline ========================================== .. autoclass:: transformers.TextGenerationPipeline + + +ConversationalPipeline +========================================== + +.. autoclass:: transformers.Conversation + +.. autoclass:: transformers.ConversationalPipeline \ No newline at end of file diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index a0fc396e51..18f6d72cef 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -104,6 +104,8 @@ from .modeling_tf_pytorch_utils import ( # Pipelines from .pipelines import ( + Conversation, + ConversationalPipeline, CsvPipelineDataFormat, FeatureExtractionPipeline, FillMaskPipeline, diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py index 8eba3c8e9c..b40f734ef2 100755 --- a/src/transformers/pipelines.py +++ b/src/transformers/pipelines.py @@ -20,11 +20,13 @@ import logging import os import pickle import sys +import uuid from abc import ABC, abstractmethod from contextlib import contextmanager from itertools import chain from os.path import abspath, exists from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from uuid import UUID import numpy as np @@ -36,7 +38,7 @@ from .modelcard import ModelCard from .tokenization_auto import AutoTokenizer from .tokenization_bert import BasicTokenizer from .tokenization_utils import PreTrainedTokenizer -from .tokenization_utils_base import PaddingStrategy +from .tokenization_utils_base import BatchEncoding, PaddingStrategy if is_tf_available(): @@ -51,6 +53,7 @@ if is_tf_available(): TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING, + TFAutoModelForCausalLM, ) if is_torch_available(): @@ -1895,6 +1898,321 @@ class TranslationPipeline(Pipeline): return results +class Conversation: + """ + Utility class containing a conversation and its history. This class is meant to be used as an input to the + :obj:`~transformers.ConversationalPipeline`. The conversation contains a number of utility function to manage the addition of new + user input and generated model responses. A conversation needs to contain an unprocessed user input before being + passed to the :obj:`~transformers.ConversationalPipeline`. This user input is either created when the class is instantiated, or by calling + `append_response("input")` after a conversation turn. + + Usage:: + + conversation = Conversation("Going to the movies tonight - any suggestions?") + + # Steps usually performed by the model when generating a response: + # 1. Mark the user input as processed (moved to the history) + conversation.mark_processed() + # 2. Append a mode response + conversation.append_response("The Big lebowski.") + + conversation.add_user_input("Is it good?") + + Arguments: + text (:obj:`str`, `optional`, defaults to :obj:`None`): + The initial user input to start the conversation. + If :obj:`None`, a user input needs to be provided manually using `add_user_input` before the conversation can begin. + conversation_id (:obj:`uuid.UUID`, `optional`, defaults to :obj:`None`): + Unique identifier for the conversation + If :obj:`None`, the random UUID4 id will be assigned to the conversation. + """ + + def __init__(self, text: str = None, conversation_id: UUID = None): + if not conversation_id: + conversation_id = uuid.uuid4() + self.uuid: UUID = conversation_id + self.past_user_inputs: List[str] = [] + self.generated_responses: List[str] = [] + self.history: List[int] = [] + self.new_user_input: Optional[str] = text + + def add_user_input(self, text: str, overwrite: bool = False): + """ + Add a user input to the conversation for the next round. This populates the internal `new_user_input` field. + + Args: + text: str, the user input for the next conversation round + overwrite: bool, flag indicating if existing and unprocessed user input should be overwritten when this function is called + + """ + if self.new_user_input: + if overwrite: + logger.warning( + 'User input added while unprocessed input was existing: "{}" was overwritten with: "{}".'.format( + self.new_user_input, text + ) + ) + self.new_user_input = text + else: + logger.warning( + 'User input added while unprocessed input was existing: "{}" new input ignored: "{}". ' + "Set `overwrite` to True to overwrite unprocessed user input".format(self.new_user_input, text) + ) + else: + self.new_user_input = text + + def mark_processed(self): + """ + Mark the conversation as processed (moves the content of `new_user_input` to `past_user_inputs`) and empties the + `new_user_input` field. + """ + if self.new_user_input: + self.past_user_inputs.append(self.new_user_input) + self.new_user_input = None + + def append_response(self, response: str): + """ + Append a response to the list of generated responses. + + Args: + response: str, the model generated response + """ + self.generated_responses.append(response) + + def set_history(self, history: List[int]): + """ + Updates the value of the history of the conversation. The history is represented by a list of `token_ids`. The + history is used by the model to generate responses based on the previous conversation turns. + + Args: + history: (list of int), history of tokens provided and generated for this conversation + """ + self.history = history + + def __repr__(self): + """ + Generates a string representation of the conversation. + + Return: + :obj:`str` or :obj:`Dict`: + + Example: + Conversation id: 7d15686b-dc94-49f2-9c4b-c9eac6a1f114 + user >> Going to the movies tonight - any suggestions? + bot >> The Big Lebowski + """ + output = "Conversation id: {} \n".format(self.uuid) + for user_input, generated_response in zip(self.past_user_inputs, self.generated_responses): + output += "user >> {} \n".format(user_input) + output += "bot >> {} \n".format(generated_response) + if self.new_user_input is not None: + output += "user >> {} \n".format(self.new_user_input) + return output + + +class ConversationalPipeline(Pipeline): + """ + Multi-turn conversational pipeline. + + Usage:: + + conversational_pipeline = pipeline("conversational") + + conversation_1 = Conversation("Going to the movies tonight - any suggestions?") + conversation_2 = Conversation("What's the last book you have read?") + + conversational_pipeline([conversation_1, conversation_2]) + + conversation_1.add_user_input("Is it an action movie?") + conversation_2.add_user_input("What is the genre of this book?") + + conversational_pipeline([conversation_1, conversation_2]) + + The models that this pipeline can use are models that have been fine-tuned on a multi-turn conversational task, + currently: "microsoft/DialoGPT-small", "microsoft/DialoGPT-medium", "microsoft/DialoGPT-large" + See the up-to-date list of available models on + `huggingface.co/models `__. + + Arguments: + model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`): + The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string + checkpoint identifier or an actual pre-trained model inheriting from + :class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for + TensorFlow. + If :obj:`None`, the default of the pipeline will be loaded. + tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`): + The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`, + a string checkpoint identifier or an actual pre-trained tokenizer inheriting from + :class:`~transformers.PreTrainedTokenizer`. + If :obj:`None`, the default of the pipeline will be loaded. + modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`): + Model card attributed to the model for this pipeline. + framework (:obj:`str`, `optional`, defaults to :obj:`None`): + The framework to use, either "pt" for PyTorch or "tf" for TensorFlow. The specified framework must be + installed. + If no framework is specified, will default to the one currently installed. If no framework is specified + and both frameworks are installed, will default to PyTorch. + args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`, defaults to :obj:`None`): + Reference to the object in charge of parsing supplied pipeline parameters. + device (:obj:`int`, `optional`, defaults to :obj:`-1`): + Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, >=0 will run the model + on the associated CUDA device id. + """ + + def __init__(self, min_length_for_response=32, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.tokenizer.eos_token_id is not None, "DialoguePipeline tokenizer should have an EOS token set" + if self.tokenizer.pad_token_id is not None: + self.pad_token_id = self.tokenizer.pad_token_id + else: + self.pad_token_id = self.tokenizer.eos_token_id + self.min_length_for_response = min_length_for_response + + def __call__( + self, + conversations: Union[Conversation, List[Conversation]], + clean_up_tokenization_spaces=True, + **generate_kwargs + ): + r""" + Args: + conversations: (list of :class:`~transformers.pipelines.Conversation`) Conversations to generate responses for + **generate_kwargs: extra kwargs passed to `self.model.generate`_ + + Returns: + list of conversations with updated generated responses for those containing a new user input + """ + + # Input validation + if isinstance(conversations, list): + for conversation in conversations: + assert isinstance( + conversation, Conversation + ), "DialoguePipeline expects a Conversation or list of Conversations as an input" + if conversation.new_user_input is None: + raise ValueError( + "Conversation with UUID {} does not contain new user input to process. " + "Add user inputs with the conversation's `add_user_input` method".format( + type(conversation.uuid) + ) + ) + assert ( + self.tokenizer.pad_token_id is not None or self.tokenizer.eos_token_id is not None + ), "Please make sure that the tokenizer has a pad_token_id or eos_token_id when using a batch input" + elif isinstance(conversations, Conversation): + conversations = [conversations] + else: + raise ValueError("DialoguePipeline expects a Conversation or list of Conversations as an input") + + with self.device_placement(): + + inputs = self._parse_and_tokenize([conversation.new_user_input for conversation in conversations]) + histories = [conversation.history for conversation in conversations] + max_length = generate_kwargs.get("max_length", self.model.config.max_length) + inputs = self._concat_inputs_history(inputs, histories, max_length) + + if self.framework == "pt": + inputs = self.ensure_tensor_on_device(**inputs) + input_length = inputs["input_ids"].shape[-1] + + elif self.framework == "tf": + input_length = tf.shape(inputs["input_ids"])[-1].numpy() + + if input_length > 0.9 * max_length: + logger.warning( + "Longest conversation length: {} is bigger than 0.9 * max_length: {}. " + "You might consider trimming the early phase of the conversation".format(input_length, max_length) + ) + generated_responses = self.model.generate( + inputs["input_ids"], attention_mask=inputs["attention_mask"], **generate_kwargs, + ) + + cleaned_history = self._clean_padding_history(generated_responses) + output = [] + for conversation_index, conversation in enumerate(conversations): + conversation.mark_processed() + conversation.generated_responses.append( + self.tokenizer.decode( + cleaned_history[conversation_index][input_length:], + skip_special_tokens=True, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + ) + ) + conversation.set_history(cleaned_history[conversation_index]) + output.append(conversation) + if len(output) == 1: + return output[0] + else: + return output + + def _parse_and_tokenize(self, *args, **kwargs): + """ + Parse arguments and tokenize, adding an EOS token at the end of the user input + """ + # Parse arguments + inputs = self._args_parser(*args, **kwargs) + inputs = self.tokenizer.batch_encode_plus(inputs, add_special_tokens=False, padding=False).get("input_ids", []) + for input in inputs: + input.append(self.tokenizer.eos_token_id) + return inputs + + def _clean_padding_history(self, generated_tensor) -> List[List[int]]: + """ + Cleans the padding history. Padding may be generated in two places when multiple conversations are provided as + an input: + - at the end of the concatenated history and new user input, so that all input to the model have the same + length + - at the end of the generated response, as some responses will be longer than others + This method cleans up these padding token so that the history for each conversation is not impacted by the + batching process. + """ + outputs = [] + for sequence in generated_tensor: + sequence_tokens = [] + is_previous_pad = False + for token in sequence: + if token == self.pad_token_id: + if is_previous_pad: + continue + else: + is_previous_pad = True + else: + is_previous_pad = False + if self.framework == "pt": + sequence_tokens.append(token.item()) + else: + sequence_tokens.append(int(token.numpy())) + + outputs.append(sequence_tokens) + return outputs + + def _concat_inputs_history(self, inputs: List[List[int]], histories: List[Optional[List[int]]], max_length: int): + """ + Builds an input prepended by the history for this conversation, allowing multi-turn conversation with context + """ + outputs = [] + for new_input, history in zip(inputs, histories): + if history is not None: + new_input = history + new_input + if len(new_input) > max_length - self.min_length_for_response: + cutoff_eos_index = 0 + while len(new_input) - cutoff_eos_index > max_length - self.min_length_for_response: + if cutoff_eos_index >= len(new_input): + break + cutoff_eos_index = new_input[cutoff_eos_index:].index(self.tokenizer.eos_token_id) + if cutoff_eos_index == 0 or cutoff_eos_index == len(new_input) - 1: + break + else: + new_input = new_input[cutoff_eos_index + 1 :] + outputs.append(new_input) + max_len = max([len(item) for item in outputs]) + outputs = [output + [self.pad_token_id] * (max_len - len(output)) for output in outputs] + outputs = BatchEncoding( + {"input_ids": outputs, "attention_mask": [1] * len(outputs)}, tensor_type=self.framework + ) + return outputs + + # Register all the supported tasks here SUPPORTED_TASKS = { "feature-extraction": { @@ -1979,6 +2297,12 @@ SUPPORTED_TASKS = { "tokenizer": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"}, }, }, + "conversational": { + "impl": ConversationalPipeline, + "tf": TFAutoModelForCausalLM if is_tf_available() else None, + "pt": AutoModelForCausalLM if is_torch_available() else None, + "default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}}, + }, } diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 3f2dd55afb..cd11fcfb1c 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -2,7 +2,7 @@ import unittest from typing import Iterable, List, Optional from transformers import pipeline -from transformers.pipelines import SUPPORTED_TASKS, DefaultArgumentHandler, Pipeline +from transformers.pipelines import SUPPORTED_TASKS, Conversation, DefaultArgumentHandler, Pipeline from transformers.testing_utils import require_tf, require_torch, slow, torch_device @@ -28,6 +28,8 @@ TRANSLATION_FINETUNED_MODELS = [ ] TF_TRANSLATION_FINETUNED_MODELS = [("patrickvonplaten/t5-tiny-random", "translation_en_to_fr")] +DIALOGUE_FINETUNED_MODELS = ["microsoft/DialoGPT-medium"] + expected_fill_mask_result = [ [ {"sequence": "My name is John", "score": 0.00782308354973793, "token": 610, "token_str": "ĠJohn"}, @@ -314,6 +316,64 @@ class MonoColumnInputTestCase(unittest.TestCase): nlp = pipeline(task="text-generation", model=model_name, tokenizer=model_name, framework="tf") self._test_mono_column_pipeline(nlp, VALID_INPUTS, {}) + @slow + @require_torch + def test_integration_torch_conversation(self): + # When + nlp = pipeline(task="conversational", device=DEFAULT_DEVICE_NUM) + conversation_1 = Conversation("Going to the movies tonight - any suggestions?") + conversation_2 = Conversation("What's the last book you have read?") + # Then + self.assertEqual(len(conversation_1.past_user_inputs), 0) + self.assertEqual(len(conversation_2.past_user_inputs), 0) + # When + result = nlp([conversation_1, conversation_2], do_sample=False, max_length=1000) + # Then + self.assertEqual(result, [conversation_1, conversation_2]) + self.assertEqual(len(result[0].past_user_inputs), 1) + self.assertEqual(len(result[1].past_user_inputs), 1) + self.assertEqual(len(result[0].generated_responses), 1) + self.assertEqual(len(result[1].generated_responses), 1) + self.assertEqual(result[0].past_user_inputs[0], "Going to the movies tonight - any suggestions?") + self.assertEqual(result[0].generated_responses[0], "The Big Lebowski") + self.assertEqual(result[1].past_user_inputs[0], "What's the last book you have read?") + self.assertEqual(result[1].generated_responses[0], "The Last Question") + # When + conversation_2.add_user_input("Why do you recommend it?") + result = nlp(conversation_2, do_sample=False, max_length=1000) + # Then + self.assertEqual(result, conversation_2) + self.assertEqual(len(result.past_user_inputs), 2) + self.assertEqual(len(result.generated_responses), 2) + self.assertEqual(result.past_user_inputs[1], "Why do you recommend it?") + self.assertEqual(result.generated_responses[1], "It's a good book.") + + @slow + @require_torch + def test_integration_torch_conversation_truncated_history(self): + # When + nlp = pipeline(task="conversational", min_length_for_response=24, device=DEFAULT_DEVICE_NUM) + conversation_1 = Conversation("Going to the movies tonight - any suggestions?") + # Then + self.assertEqual(len(conversation_1.past_user_inputs), 0) + # When + result = nlp(conversation_1, do_sample=False, max_length=36) + # Then + self.assertEqual(result, conversation_1) + self.assertEqual(len(result.past_user_inputs), 1) + self.assertEqual(len(result.generated_responses), 1) + self.assertEqual(result.past_user_inputs[0], "Going to the movies tonight - any suggestions?") + self.assertEqual(result.generated_responses[0], "The Big Lebowski") + # When + conversation_1.add_user_input("Is it an action movie?") + result = nlp(conversation_1, do_sample=False, max_length=36) + # Then + self.assertEqual(result, conversation_1) + self.assertEqual(len(result.past_user_inputs), 2) + self.assertEqual(len(result.generated_responses), 2) + self.assertEqual(result.past_user_inputs[1], "Is it an action movie?") + self.assertEqual(result.generated_responses[1], "It's a comedy.") + QA_FINETUNED_MODELS = ["sshleifer/tiny-distilbert-base-cased-distilled-squad"] @@ -450,6 +510,38 @@ class ZeroShotClassificationPipelineTests(unittest.TestCase): self._test_zero_shot_pipeline_outputs(nlp) +class DialoguePipelineTests(unittest.TestCase): + def _test_conversation_pipeline(self, nlp): + valid_inputs = [Conversation("Hi there!"), [Conversation("Hi there!"), Conversation("How are you?")]] + invalid_inputs = ["Hi there!", Conversation()] + self.assertIsNotNone(nlp) + + mono_result = nlp(valid_inputs[0]) + self.assertIsInstance(mono_result, Conversation) + + multi_result = nlp(valid_inputs[1]) + self.assertIsInstance(multi_result, list) + self.assertIsInstance(multi_result[0], Conversation) + # Inactive conversations passed to the pipeline raise a ValueError + self.assertRaises(ValueError, nlp, valid_inputs[1]) + + for bad_input in invalid_inputs: + self.assertRaises(Exception, nlp, bad_input) + self.assertRaises(Exception, nlp, invalid_inputs) + + @require_torch + def test_torch_conversation(self): + for model_name in DIALOGUE_FINETUNED_MODELS: + nlp = pipeline(task="conversational", model=model_name, tokenizer=model_name) + self._test_conversation_pipeline(nlp) + + @require_tf + def test_tf_conversation(self): + for model_name in DIALOGUE_FINETUNED_MODELS: + nlp = pipeline(task="conversational", model=model_name, tokenizer=model_name, framework="tf") + self._test_conversation_pipeline(nlp) + + class QAPipelineTests(unittest.TestCase): def _test_qa_pipeline(self, nlp): output_keys = {"score", "answer", "start", "end"} @@ -593,7 +685,6 @@ class NerPipelineTests(unittest.TestCase): class PipelineCommonTests(unittest.TestCase): - pipelines = SUPPORTED_TASKS.keys() @slow