Addition of a DialoguePipeline (#5516)
* initial commit for pipeline implementation Addition of input processing and history concatenation * Conversation pipeline tested and working for single & multiple conversation inputs * Added docstrings for dialogue pipeline * Addition of dialogue pipeline integration tests * Delete test_t5.py * Fixed max code length * Updated styling * Fixed test broken by formatting tools * Removed unused import * Added unit test for DialoguePipeline * Fixed Tensorflow compatibility * Fixed multi-framework support using framework flag * - Fixed docstring - Added `min_length_for_response` as an initialization parameter - Renamed `*args` to `conversations`, `conversations` being a `Conversation` or a `List[Conversation]` - Updated truncation to truncate entire segments of conversations, instead of cutting in the middle of a user/bot input * - renamed pipeline name from dialogue to conversational - removed hardcoded default value of 1000 and use config.max_length instead - added `append_response` and `set_history` method to the Conversation class to avoid direct fields mutation - fixed bug in history truncation method * - Updated ConversationalPipeline to accept only active conversations (otherwise a ValueError is raised) * - Simplified input tensor conversion * - Updated attention_mask value for Tensorflow compatibility * - Updated last dialogue reference to conversational & fixed integration tests * Fixed conflict with master * Updates following review comments * Updated formatting * Added Conversation and ConversationalPipeline to the library __init__, addition of docstrings for Conversation, added both to the docs * Update src/transformers/pipelines.py Updated docsting following review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -71,3 +71,11 @@ TextGenerationPipeline
|
||||
==========================================
|
||||
|
||||
.. autoclass:: transformers.TextGenerationPipeline
|
||||
|
||||
|
||||
ConversationalPipeline
|
||||
==========================================
|
||||
|
||||
.. autoclass:: transformers.Conversation
|
||||
|
||||
.. autoclass:: transformers.ConversationalPipeline
|
||||
@@ -104,6 +104,8 @@ from .modeling_tf_pytorch_utils import (
|
||||
|
||||
# Pipelines
|
||||
from .pipelines import (
|
||||
Conversation,
|
||||
ConversationalPipeline,
|
||||
CsvPipelineDataFormat,
|
||||
FeatureExtractionPipeline,
|
||||
FillMaskPipeline,
|
||||
|
||||
@@ -20,11 +20,13 @@ import logging
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from itertools import chain
|
||||
from os.path import abspath, exists
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
||||
from uuid import UUID
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -36,7 +38,7 @@ from .modelcard import ModelCard
|
||||
from .tokenization_auto import AutoTokenizer
|
||||
from .tokenization_bert import BasicTokenizer
|
||||
from .tokenization_utils import PreTrainedTokenizer
|
||||
from .tokenization_utils_base import PaddingStrategy
|
||||
from .tokenization_utils_base import BatchEncoding, PaddingStrategy
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
@@ -51,6 +53,7 @@ if is_tf_available():
|
||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
TFAutoModelForCausalLM,
|
||||
)
|
||||
|
||||
if is_torch_available():
|
||||
@@ -1895,6 +1898,321 @@ class TranslationPipeline(Pipeline):
|
||||
return results
|
||||
|
||||
|
||||
class Conversation:
    """
    Container for a single conversation and its history, meant to be passed to the
    :obj:`~transformers.ConversationalPipeline`.

    The object tracks three things: the user inputs that were already processed, the responses generated so far, and
    at most one pending (unprocessed) user input. A conversation must hold a pending user input before it is handed
    to the :obj:`~transformers.ConversationalPipeline`; that input is supplied either at construction time or by
    calling `add_user_input("input")` after a conversation turn.

    Usage::

        conversation = Conversation("Going to the movies tonight - any suggestions?")

        # Steps usually performed by the model when generating a response:
        # 1. Mark the user input as processed (moved to the history)
        conversation.mark_processed()
        # 2. Append a model response
        conversation.append_response("The Big lebowski.")

        conversation.add_user_input("Is it good?")

    Arguments:
        text (:obj:`str`, `optional`, defaults to :obj:`None`):
            The initial user input to start the conversation.
            If :obj:`None`, a user input needs to be provided manually using `add_user_input` before the
            conversation can begin.
        conversation_id (:obj:`uuid.UUID`, `optional`, defaults to :obj:`None`):
            Unique identifier for the conversation.
            If :obj:`None` (or any falsy value), a random UUID4 is assigned to the conversation.
    """

    def __init__(self, text: str = None, conversation_id: UUID = None):
        self.uuid: UUID = conversation_id if conversation_id else uuid.uuid4()
        self.past_user_inputs: List[str] = []
        self.generated_responses: List[str] = []
        # Token ids of everything said so far; replaced wholesale through `set_history`.
        self.history: List[int] = []
        self.new_user_input: Optional[str] = text

    def add_user_input(self, text: str, overwrite: bool = False):
        """
        Register the user input for the next round by populating the `new_user_input` field.

        If an unprocessed input is already pending, it is only replaced when `overwrite` is True; in both cases a
        warning is logged.

        Args:
            text: str, the user input for the next conversation round
            overwrite: bool, flag indicating if existing and unprocessed user input should be overwritten when this
                function is called
        """
        pending = self.new_user_input
        if not pending:
            self.new_user_input = text
            return

        if not overwrite:
            logger.warning(
                'User input added while unprocessed input was existing: "{}" new input ignored: "{}". '
                "Set `overwrite` to True to overwrite unprocessed user input".format(pending, text)
            )
            return

        logger.warning(
            'User input added while unprocessed input was existing: "{}" was overwritten with: "{}".'.format(
                pending, text
            )
        )
        self.new_user_input = text

    def mark_processed(self):
        """
        Mark the pending input as processed: a non-empty `new_user_input` is moved to `past_user_inputs`, and the
        `new_user_input` field is reset to :obj:`None` in every case.
        """
        pending = self.new_user_input
        if pending:
            self.past_user_inputs.append(pending)
        self.new_user_input = None

    def append_response(self, response: str):
        """
        Record a model response at the end of the list of generated responses.

        Args:
            response: str, the model generated response
        """
        self.generated_responses.append(response)

    def set_history(self, history: List[int]):
        """
        Replace the conversation history. The history is a list of `token_ids` covering the previous turns and is
        what the model conditions on when generating the next response.

        Args:
            history: (list of int), history of tokens provided and generated for this conversation
        """
        self.history = history

    def __repr__(self):
        """
        Build a human-readable transcript of the conversation.

        Processed user inputs and generated responses are listed pairwise, followed by the pending user input when
        there is one.

        Return:
            :obj:`str`

        Example:
            Conversation id: 7d15686b-dc94-49f2-9c4b-c9eac6a1f114
            user >> Going to the movies tonight - any suggestions?
            bot >> The Big Lebowski
        """
        lines = ["Conversation id: {} \n".format(self.uuid)]
        for user_input, generated_response in zip(self.past_user_inputs, self.generated_responses):
            lines.append("user >> {} \n".format(user_input))
            lines.append("bot >> {} \n".format(generated_response))
        if self.new_user_input is not None:
            lines.append("user >> {} \n".format(self.new_user_input))
        return "".join(lines)
|
||||
|
||||
|
||||
class ConversationalPipeline(Pipeline):
    """
    Multi-turn conversational pipeline.

    Usage::

        conversational_pipeline = pipeline("conversational")

        conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
        conversation_2 = Conversation("What's the last book you have read?")

        conversational_pipeline([conversation_1, conversation_2])

        conversation_1.add_user_input("Is it an action movie?")
        conversation_2.add_user_input("What is the genre of this book?")

        conversational_pipeline([conversation_1, conversation_2])

    The models that this pipeline can use are models that have been fine-tuned on a multi-turn conversational task,
    currently: "microsoft/DialoGPT-small", "microsoft/DialoGPT-medium", "microsoft/DialoGPT-large"
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?filter=conversational>`__.

    Arguments:
        min_length_for_response (:obj:`int`, `optional`, defaults to 32):
            Number of tokens reserved for the response. Conversations whose history plus new input is longer than
            ``max_length - min_length_for_response`` have their oldest turns dropped before generation.
        model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
            The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string
            checkpoint identifier or an actual pre-trained model inheriting from
            :class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
            TensorFlow.
            If :obj:`None`, the default of the pipeline will be loaded.
        tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`):
            The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`,
            a string checkpoint identifier or an actual pre-trained tokenizer inheriting from
            :class:`~transformers.PreTrainedTokenizer`.
            If :obj:`None`, the default of the pipeline will be loaded.
        modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`):
            Model card attributed to the model for this pipeline.
        framework (:obj:`str`, `optional`, defaults to :obj:`None`):
            The framework to use, either "pt" for PyTorch or "tf" for TensorFlow. The specified framework must be
            installed.
            If no framework is specified, will default to the one currently installed. If no framework is specified
            and both frameworks are installed, will default to PyTorch.
        args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`, defaults to :obj:`None`):
            Reference to the object in charge of parsing supplied pipeline parameters.
        device (:obj:`int`, `optional`, defaults to :obj:`-1`):
            Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, >=0 will run the model
            on the associated CUDA device id.
    """

    def __init__(self, min_length_for_response=32, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # An EOS token is mandatory: it is appended to every user input and is the turn delimiter the history
        # truncation relies on.
        assert self.tokenizer.eos_token_id is not None, "DialoguePipeline tokenizer should have an EOS token set"
        # Fall back to EOS for padding when the tokenizer defines no dedicated pad token.
        if self.tokenizer.pad_token_id is not None:
            self.pad_token_id = self.tokenizer.pad_token_id
        else:
            self.pad_token_id = self.tokenizer.eos_token_id
        self.min_length_for_response = min_length_for_response

    def __call__(
        self,
        conversations: Union[Conversation, List[Conversation]],
        clean_up_tokenization_spaces=True,
        **generate_kwargs
    ):
        r"""
        Generate a response for every conversation and update the conversations in place.

        Args:
            conversations: (a :class:`~transformers.pipelines.Conversation` or a list of
                :class:`~transformers.pipelines.Conversation`) Conversations to generate responses for. Each one
                must hold an unprocessed user input.
            clean_up_tokenization_spaces: (bool) forwarded to the tokenizer when decoding the generated response
            **generate_kwargs: extra kwargs passed to `self.model.generate`_. `max_length`, when present, is also
                used to decide how much history is kept; otherwise `self.model.config.max_length` is used.

        Returns:
            The updated conversation when exactly one conversation was processed, otherwise the list of updated
            conversations.

        Raises:
            ValueError: if the input is neither a Conversation nor a list, or if a conversation has no new user
                input to process.
        """

        # Input validation. Normalize to a list first so that a lone Conversation goes through the same checks as
        # a batch (an inactive single conversation must raise a ValueError too).
        if isinstance(conversations, Conversation):
            conversations = [conversations]
        elif not isinstance(conversations, list):
            raise ValueError("DialoguePipeline expects a Conversation or list of Conversations as an input")

        for conversation in conversations:
            assert isinstance(
                conversation, Conversation
            ), "DialoguePipeline expects a Conversation or list of Conversations as an input"
            if conversation.new_user_input is None:
                raise ValueError(
                    "Conversation with UUID {} does not contain new user input to process. "
                    "Add user inputs with the conversation's `add_user_input` method".format(conversation.uuid)
                )
        assert (
            self.tokenizer.pad_token_id is not None or self.tokenizer.eos_token_id is not None
        ), "Please make sure that the tokenizer has a pad_token_id or eos_token_id when using a batch input"

        with self.device_placement():

            inputs = self._parse_and_tokenize([conversation.new_user_input for conversation in conversations])
            histories = [conversation.history for conversation in conversations]
            max_length = generate_kwargs.get("max_length", self.model.config.max_length)
            inputs = self._concat_inputs_history(inputs, histories, max_length)

            if self.framework == "pt":
                inputs = self.ensure_tensor_on_device(**inputs)
                input_length = inputs["input_ids"].shape[-1]

            elif self.framework == "tf":
                input_length = tf.shape(inputs["input_ids"])[-1].numpy()

            if input_length > 0.9 * max_length:
                logger.warning(
                    "Longest conversation length: {} is bigger than 0.9 * max_length: {}. "
                    "You might consider trimming the early phase of the conversation".format(input_length, max_length)
                )
            generated_responses = self.model.generate(
                inputs["input_ids"], attention_mask=inputs["attention_mask"], **generate_kwargs
            )

            cleaned_history = self._clean_padding_history(generated_responses)
            output = []
            for conversation_index, conversation in enumerate(conversations):
                conversation.mark_processed()
                # NOTE(review): `input_length` is the padded batch length while `cleaned_history` has runs of
                # padding collapsed, so this offset looks too large for conversations shorter than the longest
                # one in the batch - confirm against batched inputs of unequal length.
                conversation.append_response(
                    self.tokenizer.decode(
                        cleaned_history[conversation_index][input_length:],
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                    )
                )
                conversation.set_history(cleaned_history[conversation_index])
                output.append(conversation)
            if len(output) == 1:
                return output[0]
            else:
                return output

    def _parse_and_tokenize(self, *args, **kwargs):
        """
        Parse arguments and tokenize, adding an EOS token at the end of each user input.

        Returns the list of token id lists (one per input), without special tokens other than the trailing EOS
        and without padding.
        """
        # Parse arguments
        inputs = self._args_parser(*args, **kwargs)
        inputs = self.tokenizer.batch_encode_plus(inputs, add_special_tokens=False, padding=False).get("input_ids", [])
        for token_ids in inputs:
            token_ids.append(self.tokenizer.eos_token_id)
        return inputs

    def _clean_padding_history(self, generated_tensor) -> List[List[int]]:
        """
        Cleans the padding history. Padding may be generated in two places when multiple conversations are provided as
        an input:
            - at the end of the concatenated history and new user input, so that all input to the model have the same
                length
            - at the end of the generated response, as some responses will be longer than others
        This method cleans up these padding token so that the history for each conversation is not impacted by the
        batching process.

        Every run of consecutive padding tokens is reduced to a single token; all other tokens are kept. The result
        is one plain list of ints per sequence of `generated_tensor`.
        """
        outputs = []
        for sequence in generated_tensor:
            sequence_tokens = []
            is_previous_pad = False
            for token in sequence:
                if token == self.pad_token_id:
                    if is_previous_pad:
                        # Second or later pad token of a run: drop it.
                        continue
                    is_previous_pad = True
                else:
                    is_previous_pad = False
                # Convert the framework scalar into a plain Python int.
                if self.framework == "pt":
                    sequence_tokens.append(token.item())
                else:
                    sequence_tokens.append(int(token.numpy()))

            outputs.append(sequence_tokens)
        return outputs

    def _concat_inputs_history(self, inputs: List[List[int]], histories: List[Optional[List[int]]], max_length: int):
        """
        Builds an input prepended by the history for this conversation, allowing multi-turn conversation with context.

        Each sequence is limited to ``max_length - self.min_length_for_response`` tokens by dropping whole turns
        (segments terminated by the EOS token) from the start of the history; the last turn is never cut, even when
        it is longer than the limit on its own. The sequences are then right-padded to a common length.

        Args:
            inputs: tokenized new user inputs, one list of token ids per conversation
            histories: token id history of each conversation (:obj:`None` is treated as no history)
            max_length: maximum total length (input and response) allowed for generation

        Returns:
            A :obj:`BatchEncoding` holding `input_ids` and a matching `attention_mask` (1 for real tokens, 0 for
            padding), converted to tensors of the pipeline framework.
        """
        max_input_length = max_length - self.min_length_for_response
        eos_token_id = self.tokenizer.eos_token_id

        sequences = []
        for new_input, history in zip(inputs, histories):
            if history is not None:
                new_input = history + new_input
            # Drop the oldest turn until the sequence fits or only the last turn is left.
            while len(new_input) > max_input_length:
                try:
                    eos_index = new_input.index(eos_token_id)
                except ValueError:
                    # No turn delimiter left: nothing can be dropped without cutting inside a turn.
                    break
                if eos_index == len(new_input) - 1:
                    # The first EOS closes the last turn: keep it whole.
                    break
                new_input = new_input[eos_index + 1 :]
            sequences.append(new_input)

        max_len = max(len(sequence) for sequence in sequences)
        input_ids = []
        attention_mask = []
        for sequence in sequences:
            pad_length = max_len - len(sequence)
            input_ids.append(sequence + [self.pad_token_id] * pad_length)
            # The mask is built from lengths rather than token values because the pad token may be the EOS token.
            attention_mask.append([1] * len(sequence) + [0] * pad_length)

        return BatchEncoding({"input_ids": input_ids, "attention_mask": attention_mask}, tensor_type=self.framework)
|
||||
|
||||
|
||||
# Register all the supported tasks here
|
||||
SUPPORTED_TASKS = {
|
||||
"feature-extraction": {
|
||||
@@ -1979,6 +2297,12 @@ SUPPORTED_TASKS = {
|
||||
"tokenizer": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
|
||||
},
|
||||
},
|
||||
"conversational": {
|
||||
"impl": ConversationalPipeline,
|
||||
"tf": TFAutoModelForCausalLM if is_tf_available() else None,
|
||||
"pt": AutoModelForCausalLM if is_torch_available() else None,
|
||||
"default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import unittest
|
||||
from typing import Iterable, List, Optional
|
||||
|
||||
from transformers import pipeline
|
||||
from transformers.pipelines import SUPPORTED_TASKS, DefaultArgumentHandler, Pipeline
|
||||
from transformers.pipelines import SUPPORTED_TASKS, Conversation, DefaultArgumentHandler, Pipeline
|
||||
from transformers.testing_utils import require_tf, require_torch, slow, torch_device
|
||||
|
||||
|
||||
@@ -28,6 +28,8 @@ TRANSLATION_FINETUNED_MODELS = [
|
||||
]
|
||||
TF_TRANSLATION_FINETUNED_MODELS = [("patrickvonplaten/t5-tiny-random", "translation_en_to_fr")]
|
||||
|
||||
DIALOGUE_FINETUNED_MODELS = ["microsoft/DialoGPT-medium"]
|
||||
|
||||
expected_fill_mask_result = [
|
||||
[
|
||||
{"sequence": "<s>My name is John</s>", "score": 0.00782308354973793, "token": 610, "token_str": "ĠJohn"},
|
||||
@@ -314,6 +316,64 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
||||
nlp = pipeline(task="text-generation", model=model_name, tokenizer=model_name, framework="tf")
|
||||
self._test_mono_column_pipeline(nlp, VALID_INPUTS, {})
|
||||
|
||||
@slow
@require_torch
def test_integration_torch_conversation(self):
    """
    Run the default conversational pipeline on a batch of two conversations, then on a single follow-up turn,
    and compare the bookkeeping and the greedy-decoded responses with the expected values.
    """
    movie_question = "Going to the movies tonight - any suggestions?"
    book_question = "What's the last book you have read?"
    follow_up_question = "Why do you recommend it?"

    # When
    nlp = pipeline(task="conversational", device=DEFAULT_DEVICE_NUM)
    conversation_1 = Conversation(movie_question)
    conversation_2 = Conversation(book_question)
    # Then
    for fresh_conversation in (conversation_1, conversation_2):
        self.assertEqual(len(fresh_conversation.past_user_inputs), 0)

    # When
    result = nlp([conversation_1, conversation_2], do_sample=False, max_length=1000)
    # Then
    self.assertEqual(result, [conversation_1, conversation_2])
    for updated_conversation in result[:2]:
        self.assertEqual(len(updated_conversation.past_user_inputs), 1)
    for updated_conversation in result[:2]:
        self.assertEqual(len(updated_conversation.generated_responses), 1)
    movie_result, book_result = result[0], result[1]
    self.assertEqual(movie_result.past_user_inputs[0], movie_question)
    self.assertEqual(movie_result.generated_responses[0], "The Big Lebowski")
    self.assertEqual(book_result.past_user_inputs[0], book_question)
    self.assertEqual(book_result.generated_responses[0], "The Last Question")

    # When
    conversation_2.add_user_input(follow_up_question)
    result = nlp(conversation_2, do_sample=False, max_length=1000)
    # Then
    self.assertEqual(result, conversation_2)
    self.assertEqual(len(result.past_user_inputs), 2)
    self.assertEqual(len(result.generated_responses), 2)
    self.assertEqual(result.past_user_inputs[1], follow_up_question)
    self.assertEqual(result.generated_responses[1], "It's a good book.")
|
||||
|
||||
@slow
@require_torch
def test_integration_torch_conversation_truncated_history(self):
    """
    Run two turns of one conversation with a small generation budget (`max_length=36`,
    `min_length_for_response=24`) and compare the bookkeeping and the greedy-decoded responses with the
    expected values.
    """
    generation_options = {"do_sample": False, "max_length": 36}
    first_question = "Going to the movies tonight - any suggestions?"
    second_question = "Is it an action movie?"

    # When
    nlp = pipeline(task="conversational", min_length_for_response=24, device=DEFAULT_DEVICE_NUM)
    conversation_1 = Conversation(first_question)
    # Then
    self.assertEqual(len(conversation_1.past_user_inputs), 0)

    # When
    result = nlp(conversation_1, **generation_options)
    # Then
    self.assertEqual(result, conversation_1)
    self.assertEqual(len(result.past_user_inputs), 1)
    self.assertEqual(len(result.generated_responses), 1)
    self.assertEqual(result.past_user_inputs[0], first_question)
    self.assertEqual(result.generated_responses[0], "The Big Lebowski")

    # When
    conversation_1.add_user_input(second_question)
    result = nlp(conversation_1, **generation_options)
    # Then
    self.assertEqual(result, conversation_1)
    self.assertEqual(len(result.past_user_inputs), 2)
    self.assertEqual(len(result.generated_responses), 2)
    self.assertEqual(result.past_user_inputs[1], second_question)
    self.assertEqual(result.generated_responses[1], "It's a comedy.")
|
||||
|
||||
|
||||
QA_FINETUNED_MODELS = ["sshleifer/tiny-distilbert-base-cased-distilled-squad"]
|
||||
|
||||
@@ -450,6 +510,38 @@ class ZeroShotClassificationPipelineTests(unittest.TestCase):
|
||||
self._test_zero_shot_pipeline_outputs(nlp)
|
||||
|
||||
|
||||
class DialoguePipelineTests(unittest.TestCase):
    """Smoke tests for the conversational pipeline with PyTorch and TensorFlow models."""

    def _test_conversation_pipeline(self, nlp):
        """
        Check the return types for a single conversation and for a batch, and that inactive or invalid inputs
        are rejected.
        """
        single_conversation = Conversation("Hi there!")
        conversation_batch = [Conversation("Hi there!"), Conversation("How are you?")]
        invalid_inputs = ["Hi there!", Conversation()]
        self.assertIsNotNone(nlp)

        mono_result = nlp(single_conversation)
        self.assertIsInstance(mono_result, Conversation)

        multi_result = nlp(conversation_batch)
        self.assertIsInstance(multi_result, list)
        self.assertIsInstance(multi_result[0], Conversation)
        # Inactive conversations passed to the pipeline raise a ValueError
        with self.assertRaises(ValueError):
            nlp(conversation_batch)

        for bad_input in invalid_inputs:
            with self.assertRaises(Exception):
                nlp(bad_input)
        with self.assertRaises(Exception):
            nlp(invalid_inputs)

    @require_torch
    def test_torch_conversation(self):
        """Run the shared checks on every fine-tuned dialogue model with the default framework."""
        for model_name in DIALOGUE_FINETUNED_MODELS:
            self._test_conversation_pipeline(
                pipeline(task="conversational", model=model_name, tokenizer=model_name)
            )

    @require_tf
    def test_tf_conversation(self):
        """Run the shared checks on every fine-tuned dialogue model with the TensorFlow framework."""
        for model_name in DIALOGUE_FINETUNED_MODELS:
            self._test_conversation_pipeline(
                pipeline(task="conversational", model=model_name, tokenizer=model_name, framework="tf")
            )
|
||||
|
||||
|
||||
class QAPipelineTests(unittest.TestCase):
|
||||
def _test_qa_pipeline(self, nlp):
|
||||
output_keys = {"score", "answer", "start", "end"}
|
||||
@@ -593,7 +685,6 @@ class NerPipelineTests(unittest.TestCase):
|
||||
|
||||
|
||||
class PipelineCommonTests(unittest.TestCase):
|
||||
|
||||
pipelines = SUPPORTED_TASKS.keys()
|
||||
|
||||
@slow
|
||||
|
||||
Reference in New Issue
Block a user