Addition of a DialoguePipeline (#5516)
* initial commit for pipeline implementation Addition of input processing and history concatenation * Conversation pipeline tested and working for single & multiple conversation inputs * Added docstrings for dialogue pipeline * Addition of dialogue pipeline integration tests * Delete test_t5.py * Fixed max code length * Updated styling * Fixed test broken by formatting tools * Removed unused import * Added unit test for DialoguePipeline * Fixed Tensorflow compatibility * Fixed multi-framework support using framework flag * - Fixed docstring - Added `min_length_for_response` as an initialization parameter - Renamed `*args` to `conversations`, `conversations` being a `Conversation` or a `List[Conversation]` - Updated truncation to truncate entire segments of conversations, instead of cutting in the middle of a user/bot input * - renamed pipeline name from dialogue to conversational - removed hardcoded default value of 1000 and use config.max_length instead - added `append_response` and `set_history` method to the Conversation class to avoid direct fields mutation - fixed bug in history truncation method * - Updated ConversationalPipeline to accept only active conversations (otherwise a ValueError is raised) * - Simplified input tensor conversion * - Updated attention_mask value for Tensorflow compatibility * - Updated last dialogue reference to conversational & fixed integration tests * Fixed conflict with master * Updates following review comments * Updated formatting * Added Conversation and ConversationalPipeline to the library __init__, addition of docstrings for Conversation, added both to the docs * Update src/transformers/pipelines.py Updated docsting following review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -71,3 +71,11 @@ TextGenerationPipeline
|
|||||||
==========================================
|
==========================================
|
||||||
|
|
||||||
.. autoclass:: transformers.TextGenerationPipeline
|
.. autoclass:: transformers.TextGenerationPipeline
|
||||||
|
|
||||||
|
|
||||||
|
ConversationalPipeline
|
||||||
|
==========================================
|
||||||
|
|
||||||
|
.. autoclass:: transformers.Conversation
|
||||||
|
|
||||||
|
.. autoclass:: transformers.ConversationalPipeline
|
||||||
@@ -104,6 +104,8 @@ from .modeling_tf_pytorch_utils import (
|
|||||||
|
|
||||||
# Pipelines
|
# Pipelines
|
||||||
from .pipelines import (
|
from .pipelines import (
|
||||||
|
Conversation,
|
||||||
|
ConversationalPipeline,
|
||||||
CsvPipelineDataFormat,
|
CsvPipelineDataFormat,
|
||||||
FeatureExtractionPipeline,
|
FeatureExtractionPipeline,
|
||||||
FillMaskPipeline,
|
FillMaskPipeline,
|
||||||
|
|||||||
@@ -20,11 +20,13 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import sys
|
import sys
|
||||||
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from os.path import abspath, exists
|
from os.path import abspath, exists
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -36,7 +38,7 @@ from .modelcard import ModelCard
|
|||||||
from .tokenization_auto import AutoTokenizer
|
from .tokenization_auto import AutoTokenizer
|
||||||
from .tokenization_bert import BasicTokenizer
|
from .tokenization_bert import BasicTokenizer
|
||||||
from .tokenization_utils import PreTrainedTokenizer
|
from .tokenization_utils import PreTrainedTokenizer
|
||||||
from .tokenization_utils_base import PaddingStrategy
|
from .tokenization_utils_base import BatchEncoding, PaddingStrategy
|
||||||
|
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
@@ -51,6 +53,7 @@ if is_tf_available():
|
|||||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||||
|
TFAutoModelForCausalLM,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -1895,6 +1898,321 @@ class TranslationPipeline(Pipeline):
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
class Conversation:
|
||||||
|
"""
|
||||||
|
Utility class containing a conversation and its history. This class is meant to be used as an input to the
|
||||||
|
:obj:`~transformers.ConversationalPipeline`. The conversation contains a number of utility function to manage the addition of new
|
||||||
|
user input and generated model responses. A conversation needs to contain an unprocessed user input before being
|
||||||
|
passed to the :obj:`~transformers.ConversationalPipeline`. This user input is either created when the class is instantiated, or by calling
|
||||||
|
`append_response("input")` after a conversation turn.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
conversation = Conversation("Going to the movies tonight - any suggestions?")
|
||||||
|
|
||||||
|
# Steps usually performed by the model when generating a response:
|
||||||
|
# 1. Mark the user input as processed (moved to the history)
|
||||||
|
conversation.mark_processed()
|
||||||
|
# 2. Append a mode response
|
||||||
|
conversation.append_response("The Big lebowski.")
|
||||||
|
|
||||||
|
conversation.add_user_input("Is it good?")
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
text (:obj:`str`, `optional`, defaults to :obj:`None`):
|
||||||
|
The initial user input to start the conversation.
|
||||||
|
If :obj:`None`, a user input needs to be provided manually using `add_user_input` before the conversation can begin.
|
||||||
|
conversation_id (:obj:`uuid.UUID`, `optional`, defaults to :obj:`None`):
|
||||||
|
Unique identifier for the conversation
|
||||||
|
If :obj:`None`, the random UUID4 id will be assigned to the conversation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, text: str = None, conversation_id: UUID = None):
|
||||||
|
if not conversation_id:
|
||||||
|
conversation_id = uuid.uuid4()
|
||||||
|
self.uuid: UUID = conversation_id
|
||||||
|
self.past_user_inputs: List[str] = []
|
||||||
|
self.generated_responses: List[str] = []
|
||||||
|
self.history: List[int] = []
|
||||||
|
self.new_user_input: Optional[str] = text
|
||||||
|
|
||||||
|
def add_user_input(self, text: str, overwrite: bool = False):
|
||||||
|
"""
|
||||||
|
Add a user input to the conversation for the next round. This populates the internal `new_user_input` field.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: str, the user input for the next conversation round
|
||||||
|
overwrite: bool, flag indicating if existing and unprocessed user input should be overwritten when this function is called
|
||||||
|
|
||||||
|
"""
|
||||||
|
if self.new_user_input:
|
||||||
|
if overwrite:
|
||||||
|
logger.warning(
|
||||||
|
'User input added while unprocessed input was existing: "{}" was overwritten with: "{}".'.format(
|
||||||
|
self.new_user_input, text
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.new_user_input = text
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
'User input added while unprocessed input was existing: "{}" new input ignored: "{}". '
|
||||||
|
"Set `overwrite` to True to overwrite unprocessed user input".format(self.new_user_input, text)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.new_user_input = text
|
||||||
|
|
||||||
|
def mark_processed(self):
|
||||||
|
"""
|
||||||
|
Mark the conversation as processed (moves the content of `new_user_input` to `past_user_inputs`) and empties the
|
||||||
|
`new_user_input` field.
|
||||||
|
"""
|
||||||
|
if self.new_user_input:
|
||||||
|
self.past_user_inputs.append(self.new_user_input)
|
||||||
|
self.new_user_input = None
|
||||||
|
|
||||||
|
def append_response(self, response: str):
|
||||||
|
"""
|
||||||
|
Append a response to the list of generated responses.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: str, the model generated response
|
||||||
|
"""
|
||||||
|
self.generated_responses.append(response)
|
||||||
|
|
||||||
|
def set_history(self, history: List[int]):
|
||||||
|
"""
|
||||||
|
Updates the value of the history of the conversation. The history is represented by a list of `token_ids`. The
|
||||||
|
history is used by the model to generate responses based on the previous conversation turns.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
history: (list of int), history of tokens provided and generated for this conversation
|
||||||
|
"""
|
||||||
|
self.history = history
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
"""
|
||||||
|
Generates a string representation of the conversation.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
:obj:`str` or :obj:`Dict`:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
Conversation id: 7d15686b-dc94-49f2-9c4b-c9eac6a1f114
|
||||||
|
user >> Going to the movies tonight - any suggestions?
|
||||||
|
bot >> The Big Lebowski
|
||||||
|
"""
|
||||||
|
output = "Conversation id: {} \n".format(self.uuid)
|
||||||
|
for user_input, generated_response in zip(self.past_user_inputs, self.generated_responses):
|
||||||
|
output += "user >> {} \n".format(user_input)
|
||||||
|
output += "bot >> {} \n".format(generated_response)
|
||||||
|
if self.new_user_input is not None:
|
||||||
|
output += "user >> {} \n".format(self.new_user_input)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationalPipeline(Pipeline):
|
||||||
|
"""
|
||||||
|
Multi-turn conversational pipeline.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
conversational_pipeline = pipeline("conversational")
|
||||||
|
|
||||||
|
conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
|
||||||
|
conversation_2 = Conversation("What's the last book you have read?")
|
||||||
|
|
||||||
|
conversational_pipeline([conversation_1, conversation_2])
|
||||||
|
|
||||||
|
conversation_1.add_user_input("Is it an action movie?")
|
||||||
|
conversation_2.add_user_input("What is the genre of this book?")
|
||||||
|
|
||||||
|
conversational_pipeline([conversation_1, conversation_2])
|
||||||
|
|
||||||
|
The models that this pipeline can use are models that have been fine-tuned on a multi-turn conversational task,
|
||||||
|
currently: "microsoft/DialoGPT-small", "microsoft/DialoGPT-medium", "microsoft/DialoGPT-large"
|
||||||
|
See the up-to-date list of available models on
|
||||||
|
`huggingface.co/models <https://huggingface.co/models?filter=conversational>`__.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
|
||||||
|
The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string
|
||||||
|
checkpoint identifier or an actual pre-trained model inheriting from
|
||||||
|
:class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
|
||||||
|
TensorFlow.
|
||||||
|
If :obj:`None`, the default of the pipeline will be loaded.
|
||||||
|
tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`):
|
||||||
|
The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`,
|
||||||
|
a string checkpoint identifier or an actual pre-trained tokenizer inheriting from
|
||||||
|
:class:`~transformers.PreTrainedTokenizer`.
|
||||||
|
If :obj:`None`, the default of the pipeline will be loaded.
|
||||||
|
modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`):
|
||||||
|
Model card attributed to the model for this pipeline.
|
||||||
|
framework (:obj:`str`, `optional`, defaults to :obj:`None`):
|
||||||
|
The framework to use, either "pt" for PyTorch or "tf" for TensorFlow. The specified framework must be
|
||||||
|
installed.
|
||||||
|
If no framework is specified, will default to the one currently installed. If no framework is specified
|
||||||
|
and both frameworks are installed, will default to PyTorch.
|
||||||
|
args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`, defaults to :obj:`None`):
|
||||||
|
Reference to the object in charge of parsing supplied pipeline parameters.
|
||||||
|
device (:obj:`int`, `optional`, defaults to :obj:`-1`):
|
||||||
|
Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, >=0 will run the model
|
||||||
|
on the associated CUDA device id.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, min_length_for_response=32, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
assert self.tokenizer.eos_token_id is not None, "DialoguePipeline tokenizer should have an EOS token set"
|
||||||
|
if self.tokenizer.pad_token_id is not None:
|
||||||
|
self.pad_token_id = self.tokenizer.pad_token_id
|
||||||
|
else:
|
||||||
|
self.pad_token_id = self.tokenizer.eos_token_id
|
||||||
|
self.min_length_for_response = min_length_for_response
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
conversations: Union[Conversation, List[Conversation]],
|
||||||
|
clean_up_tokenization_spaces=True,
|
||||||
|
**generate_kwargs
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
conversations: (list of :class:`~transformers.pipelines.Conversation`) Conversations to generate responses for
|
||||||
|
**generate_kwargs: extra kwargs passed to `self.model.generate`_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list of conversations with updated generated responses for those containing a new user input
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Input validation
|
||||||
|
if isinstance(conversations, list):
|
||||||
|
for conversation in conversations:
|
||||||
|
assert isinstance(
|
||||||
|
conversation, Conversation
|
||||||
|
), "DialoguePipeline expects a Conversation or list of Conversations as an input"
|
||||||
|
if conversation.new_user_input is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Conversation with UUID {} does not contain new user input to process. "
|
||||||
|
"Add user inputs with the conversation's `add_user_input` method".format(
|
||||||
|
type(conversation.uuid)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
self.tokenizer.pad_token_id is not None or self.tokenizer.eos_token_id is not None
|
||||||
|
), "Please make sure that the tokenizer has a pad_token_id or eos_token_id when using a batch input"
|
||||||
|
elif isinstance(conversations, Conversation):
|
||||||
|
conversations = [conversations]
|
||||||
|
else:
|
||||||
|
raise ValueError("DialoguePipeline expects a Conversation or list of Conversations as an input")
|
||||||
|
|
||||||
|
with self.device_placement():
|
||||||
|
|
||||||
|
inputs = self._parse_and_tokenize([conversation.new_user_input for conversation in conversations])
|
||||||
|
histories = [conversation.history for conversation in conversations]
|
||||||
|
max_length = generate_kwargs.get("max_length", self.model.config.max_length)
|
||||||
|
inputs = self._concat_inputs_history(inputs, histories, max_length)
|
||||||
|
|
||||||
|
if self.framework == "pt":
|
||||||
|
inputs = self.ensure_tensor_on_device(**inputs)
|
||||||
|
input_length = inputs["input_ids"].shape[-1]
|
||||||
|
|
||||||
|
elif self.framework == "tf":
|
||||||
|
input_length = tf.shape(inputs["input_ids"])[-1].numpy()
|
||||||
|
|
||||||
|
if input_length > 0.9 * max_length:
|
||||||
|
logger.warning(
|
||||||
|
"Longest conversation length: {} is bigger than 0.9 * max_length: {}. "
|
||||||
|
"You might consider trimming the early phase of the conversation".format(input_length, max_length)
|
||||||
|
)
|
||||||
|
generated_responses = self.model.generate(
|
||||||
|
inputs["input_ids"], attention_mask=inputs["attention_mask"], **generate_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
cleaned_history = self._clean_padding_history(generated_responses)
|
||||||
|
output = []
|
||||||
|
for conversation_index, conversation in enumerate(conversations):
|
||||||
|
conversation.mark_processed()
|
||||||
|
conversation.generated_responses.append(
|
||||||
|
self.tokenizer.decode(
|
||||||
|
cleaned_history[conversation_index][input_length:],
|
||||||
|
skip_special_tokens=True,
|
||||||
|
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
conversation.set_history(cleaned_history[conversation_index])
|
||||||
|
output.append(conversation)
|
||||||
|
if len(output) == 1:
|
||||||
|
return output[0]
|
||||||
|
else:
|
||||||
|
return output
|
||||||
|
|
||||||
|
def _parse_and_tokenize(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Parse arguments and tokenize, adding an EOS token at the end of the user input
|
||||||
|
"""
|
||||||
|
# Parse arguments
|
||||||
|
inputs = self._args_parser(*args, **kwargs)
|
||||||
|
inputs = self.tokenizer.batch_encode_plus(inputs, add_special_tokens=False, padding=False).get("input_ids", [])
|
||||||
|
for input in inputs:
|
||||||
|
input.append(self.tokenizer.eos_token_id)
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
def _clean_padding_history(self, generated_tensor) -> List[List[int]]:
|
||||||
|
"""
|
||||||
|
Cleans the padding history. Padding may be generated in two places when multiple conversations are provided as
|
||||||
|
an input:
|
||||||
|
- at the end of the concatenated history and new user input, so that all input to the model have the same
|
||||||
|
length
|
||||||
|
- at the end of the generated response, as some responses will be longer than others
|
||||||
|
This method cleans up these padding token so that the history for each conversation is not impacted by the
|
||||||
|
batching process.
|
||||||
|
"""
|
||||||
|
outputs = []
|
||||||
|
for sequence in generated_tensor:
|
||||||
|
sequence_tokens = []
|
||||||
|
is_previous_pad = False
|
||||||
|
for token in sequence:
|
||||||
|
if token == self.pad_token_id:
|
||||||
|
if is_previous_pad:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
is_previous_pad = True
|
||||||
|
else:
|
||||||
|
is_previous_pad = False
|
||||||
|
if self.framework == "pt":
|
||||||
|
sequence_tokens.append(token.item())
|
||||||
|
else:
|
||||||
|
sequence_tokens.append(int(token.numpy()))
|
||||||
|
|
||||||
|
outputs.append(sequence_tokens)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def _concat_inputs_history(self, inputs: List[List[int]], histories: List[Optional[List[int]]], max_length: int):
|
||||||
|
"""
|
||||||
|
Builds an input prepended by the history for this conversation, allowing multi-turn conversation with context
|
||||||
|
"""
|
||||||
|
outputs = []
|
||||||
|
for new_input, history in zip(inputs, histories):
|
||||||
|
if history is not None:
|
||||||
|
new_input = history + new_input
|
||||||
|
if len(new_input) > max_length - self.min_length_for_response:
|
||||||
|
cutoff_eos_index = 0
|
||||||
|
while len(new_input) - cutoff_eos_index > max_length - self.min_length_for_response:
|
||||||
|
if cutoff_eos_index >= len(new_input):
|
||||||
|
break
|
||||||
|
cutoff_eos_index = new_input[cutoff_eos_index:].index(self.tokenizer.eos_token_id)
|
||||||
|
if cutoff_eos_index == 0 or cutoff_eos_index == len(new_input) - 1:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
new_input = new_input[cutoff_eos_index + 1 :]
|
||||||
|
outputs.append(new_input)
|
||||||
|
max_len = max([len(item) for item in outputs])
|
||||||
|
outputs = [output + [self.pad_token_id] * (max_len - len(output)) for output in outputs]
|
||||||
|
outputs = BatchEncoding(
|
||||||
|
{"input_ids": outputs, "attention_mask": [1] * len(outputs)}, tensor_type=self.framework
|
||||||
|
)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
# Register all the supported tasks here
|
# Register all the supported tasks here
|
||||||
SUPPORTED_TASKS = {
|
SUPPORTED_TASKS = {
|
||||||
"feature-extraction": {
|
"feature-extraction": {
|
||||||
@@ -1979,6 +2297,12 @@ SUPPORTED_TASKS = {
|
|||||||
"tokenizer": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
|
"tokenizer": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"conversational": {
|
||||||
|
"impl": ConversationalPipeline,
|
||||||
|
"tf": TFAutoModelForCausalLM if is_tf_available() else None,
|
||||||
|
"pt": AutoModelForCausalLM if is_torch_available() else None,
|
||||||
|
"default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import unittest
|
|||||||
from typing import Iterable, List, Optional
|
from typing import Iterable, List, Optional
|
||||||
|
|
||||||
from transformers import pipeline
|
from transformers import pipeline
|
||||||
from transformers.pipelines import SUPPORTED_TASKS, DefaultArgumentHandler, Pipeline
|
from transformers.pipelines import SUPPORTED_TASKS, Conversation, DefaultArgumentHandler, Pipeline
|
||||||
from transformers.testing_utils import require_tf, require_torch, slow, torch_device
|
from transformers.testing_utils import require_tf, require_torch, slow, torch_device
|
||||||
|
|
||||||
|
|
||||||
@@ -28,6 +28,8 @@ TRANSLATION_FINETUNED_MODELS = [
|
|||||||
]
|
]
|
||||||
TF_TRANSLATION_FINETUNED_MODELS = [("patrickvonplaten/t5-tiny-random", "translation_en_to_fr")]
|
TF_TRANSLATION_FINETUNED_MODELS = [("patrickvonplaten/t5-tiny-random", "translation_en_to_fr")]
|
||||||
|
|
||||||
|
DIALOGUE_FINETUNED_MODELS = ["microsoft/DialoGPT-medium"]
|
||||||
|
|
||||||
expected_fill_mask_result = [
|
expected_fill_mask_result = [
|
||||||
[
|
[
|
||||||
{"sequence": "<s>My name is John</s>", "score": 0.00782308354973793, "token": 610, "token_str": "ĠJohn"},
|
{"sequence": "<s>My name is John</s>", "score": 0.00782308354973793, "token": 610, "token_str": "ĠJohn"},
|
||||||
@@ -314,6 +316,64 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
|||||||
nlp = pipeline(task="text-generation", model=model_name, tokenizer=model_name, framework="tf")
|
nlp = pipeline(task="text-generation", model=model_name, tokenizer=model_name, framework="tf")
|
||||||
self._test_mono_column_pipeline(nlp, VALID_INPUTS, {})
|
self._test_mono_column_pipeline(nlp, VALID_INPUTS, {})
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
def test_integration_torch_conversation(self):
|
||||||
|
# When
|
||||||
|
nlp = pipeline(task="conversational", device=DEFAULT_DEVICE_NUM)
|
||||||
|
conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
|
||||||
|
conversation_2 = Conversation("What's the last book you have read?")
|
||||||
|
# Then
|
||||||
|
self.assertEqual(len(conversation_1.past_user_inputs), 0)
|
||||||
|
self.assertEqual(len(conversation_2.past_user_inputs), 0)
|
||||||
|
# When
|
||||||
|
result = nlp([conversation_1, conversation_2], do_sample=False, max_length=1000)
|
||||||
|
# Then
|
||||||
|
self.assertEqual(result, [conversation_1, conversation_2])
|
||||||
|
self.assertEqual(len(result[0].past_user_inputs), 1)
|
||||||
|
self.assertEqual(len(result[1].past_user_inputs), 1)
|
||||||
|
self.assertEqual(len(result[0].generated_responses), 1)
|
||||||
|
self.assertEqual(len(result[1].generated_responses), 1)
|
||||||
|
self.assertEqual(result[0].past_user_inputs[0], "Going to the movies tonight - any suggestions?")
|
||||||
|
self.assertEqual(result[0].generated_responses[0], "The Big Lebowski")
|
||||||
|
self.assertEqual(result[1].past_user_inputs[0], "What's the last book you have read?")
|
||||||
|
self.assertEqual(result[1].generated_responses[0], "The Last Question")
|
||||||
|
# When
|
||||||
|
conversation_2.add_user_input("Why do you recommend it?")
|
||||||
|
result = nlp(conversation_2, do_sample=False, max_length=1000)
|
||||||
|
# Then
|
||||||
|
self.assertEqual(result, conversation_2)
|
||||||
|
self.assertEqual(len(result.past_user_inputs), 2)
|
||||||
|
self.assertEqual(len(result.generated_responses), 2)
|
||||||
|
self.assertEqual(result.past_user_inputs[1], "Why do you recommend it?")
|
||||||
|
self.assertEqual(result.generated_responses[1], "It's a good book.")
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
def test_integration_torch_conversation_truncated_history(self):
|
||||||
|
# When
|
||||||
|
nlp = pipeline(task="conversational", min_length_for_response=24, device=DEFAULT_DEVICE_NUM)
|
||||||
|
conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
|
||||||
|
# Then
|
||||||
|
self.assertEqual(len(conversation_1.past_user_inputs), 0)
|
||||||
|
# When
|
||||||
|
result = nlp(conversation_1, do_sample=False, max_length=36)
|
||||||
|
# Then
|
||||||
|
self.assertEqual(result, conversation_1)
|
||||||
|
self.assertEqual(len(result.past_user_inputs), 1)
|
||||||
|
self.assertEqual(len(result.generated_responses), 1)
|
||||||
|
self.assertEqual(result.past_user_inputs[0], "Going to the movies tonight - any suggestions?")
|
||||||
|
self.assertEqual(result.generated_responses[0], "The Big Lebowski")
|
||||||
|
# When
|
||||||
|
conversation_1.add_user_input("Is it an action movie?")
|
||||||
|
result = nlp(conversation_1, do_sample=False, max_length=36)
|
||||||
|
# Then
|
||||||
|
self.assertEqual(result, conversation_1)
|
||||||
|
self.assertEqual(len(result.past_user_inputs), 2)
|
||||||
|
self.assertEqual(len(result.generated_responses), 2)
|
||||||
|
self.assertEqual(result.past_user_inputs[1], "Is it an action movie?")
|
||||||
|
self.assertEqual(result.generated_responses[1], "It's a comedy.")
|
||||||
|
|
||||||
|
|
||||||
QA_FINETUNED_MODELS = ["sshleifer/tiny-distilbert-base-cased-distilled-squad"]
|
QA_FINETUNED_MODELS = ["sshleifer/tiny-distilbert-base-cased-distilled-squad"]
|
||||||
|
|
||||||
@@ -450,6 +510,38 @@ class ZeroShotClassificationPipelineTests(unittest.TestCase):
|
|||||||
self._test_zero_shot_pipeline_outputs(nlp)
|
self._test_zero_shot_pipeline_outputs(nlp)
|
||||||
|
|
||||||
|
|
||||||
|
class DialoguePipelineTests(unittest.TestCase):
|
||||||
|
def _test_conversation_pipeline(self, nlp):
|
||||||
|
valid_inputs = [Conversation("Hi there!"), [Conversation("Hi there!"), Conversation("How are you?")]]
|
||||||
|
invalid_inputs = ["Hi there!", Conversation()]
|
||||||
|
self.assertIsNotNone(nlp)
|
||||||
|
|
||||||
|
mono_result = nlp(valid_inputs[0])
|
||||||
|
self.assertIsInstance(mono_result, Conversation)
|
||||||
|
|
||||||
|
multi_result = nlp(valid_inputs[1])
|
||||||
|
self.assertIsInstance(multi_result, list)
|
||||||
|
self.assertIsInstance(multi_result[0], Conversation)
|
||||||
|
# Inactive conversations passed to the pipeline raise a ValueError
|
||||||
|
self.assertRaises(ValueError, nlp, valid_inputs[1])
|
||||||
|
|
||||||
|
for bad_input in invalid_inputs:
|
||||||
|
self.assertRaises(Exception, nlp, bad_input)
|
||||||
|
self.assertRaises(Exception, nlp, invalid_inputs)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_torch_conversation(self):
|
||||||
|
for model_name in DIALOGUE_FINETUNED_MODELS:
|
||||||
|
nlp = pipeline(task="conversational", model=model_name, tokenizer=model_name)
|
||||||
|
self._test_conversation_pipeline(nlp)
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
def test_tf_conversation(self):
|
||||||
|
for model_name in DIALOGUE_FINETUNED_MODELS:
|
||||||
|
nlp = pipeline(task="conversational", model=model_name, tokenizer=model_name, framework="tf")
|
||||||
|
self._test_conversation_pipeline(nlp)
|
||||||
|
|
||||||
|
|
||||||
class QAPipelineTests(unittest.TestCase):
|
class QAPipelineTests(unittest.TestCase):
|
||||||
def _test_qa_pipeline(self, nlp):
|
def _test_qa_pipeline(self, nlp):
|
||||||
output_keys = {"score", "answer", "start", "end"}
|
output_keys = {"score", "answer", "start", "end"}
|
||||||
@@ -593,7 +685,6 @@ class NerPipelineTests(unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
class PipelineCommonTests(unittest.TestCase):
|
class PipelineCommonTests(unittest.TestCase):
|
||||||
|
|
||||||
pipelines = SUPPORTED_TASKS.keys()
|
pipelines = SUPPORTED_TASKS.keys()
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
|
|||||||
Reference in New Issue
Block a user