Addition of a DialoguePipeline (#5516)
* initial commit for pipeline implementation Addition of input processing and history concatenation * Conversation pipeline tested and working for single & multiple conversation inputs * Added docstrings for dialogue pipeline * Addition of dialogue pipeline integration tests * Delete test_t5.py * Fixed max code length * Updated styling * Fixed test broken by formatting tools * Removed unused import * Added unit test for DialoguePipeline * Fixed TensorFlow compatibility * Fixed multi-framework support using framework flag * - Fixed docstring - Added `min_length_for_response` as an initialization parameter - Renamed `*args` to `conversations`, `conversations` being a `Conversation` or a `List[Conversation]` - Updated truncation to truncate entire segments of conversations, instead of cutting in the middle of a user/bot input * - renamed pipeline name from dialogue to conversational - removed hardcoded default value of 1000 and use config.max_length instead - added `append_response` and `set_history` methods to the Conversation class to avoid direct fields mutation - fixed bug in history truncation method * - Updated ConversationalPipeline to accept only active conversations (otherwise a ValueError is raised) * - Simplified input tensor conversion * - Updated attention_mask value for TensorFlow compatibility * - Updated last dialogue reference to conversational & fixed integration tests * Fixed conflict with master * Updates following review comments * Updated formatting * Added Conversation and ConversationalPipeline to the library __init__, addition of docstrings for Conversation, added both to the docs * Update src/transformers/pipelines.py Updated docstring following review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -2,7 +2,7 @@ import unittest
|
||||
from typing import Iterable, List, Optional
|
||||
|
||||
from transformers import pipeline
|
||||
from transformers.pipelines import SUPPORTED_TASKS, DefaultArgumentHandler, Pipeline
|
||||
from transformers.pipelines import SUPPORTED_TASKS, Conversation, DefaultArgumentHandler, Pipeline
|
||||
from transformers.testing_utils import require_tf, require_torch, slow, torch_device
|
||||
|
||||
|
||||
@@ -28,6 +28,8 @@ TRANSLATION_FINETUNED_MODELS = [
|
||||
]
|
||||
TF_TRANSLATION_FINETUNED_MODELS = [("patrickvonplaten/t5-tiny-random", "translation_en_to_fr")]
|
||||
|
||||
DIALOGUE_FINETUNED_MODELS = ["microsoft/DialoGPT-medium"]
|
||||
|
||||
expected_fill_mask_result = [
|
||||
[
|
||||
{"sequence": "<s>My name is John</s>", "score": 0.00782308354973793, "token": 610, "token_str": "ĠJohn"},
|
||||
@@ -314,6 +316,64 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
||||
nlp = pipeline(task="text-generation", model=model_name, tokenizer=model_name, framework="tf")
|
||||
self._test_mono_column_pipeline(nlp, VALID_INPUTS, {})
|
||||
|
||||
@slow
@require_torch
def test_integration_torch_conversation(self):
    """Run the default ``conversational`` pipeline on two conversations.

    First both conversations are sent as a list and the recorded user
    inputs / generated responses are checked; then a follow-up user input is
    added to the second conversation and it is sent on its own.
    Generation uses ``do_sample=False`` and ``max_length=1000`` throughout.
    """
    movie_prompt = "Going to the movies tonight - any suggestions?"
    book_prompt = "What's the last book you have read?"
    follow_up_prompt = "Why do you recommend it?"

    # When
    nlp = pipeline(task="conversational", device=DEFAULT_DEVICE_NUM)
    conversation_1 = Conversation(movie_prompt)
    conversation_2 = Conversation(book_prompt)
    # Then
    for fresh_conversation in (conversation_1, conversation_2):
        self.assertEqual(len(fresh_conversation.past_user_inputs), 0)

    # When
    result = nlp([conversation_1, conversation_2], do_sample=False, max_length=1000)
    # Then
    self.assertEqual(result, [conversation_1, conversation_2])
    for processed in result:
        self.assertEqual(len(processed.past_user_inputs), 1)
    for processed in result:
        self.assertEqual(len(processed.generated_responses), 1)
    movie_result, book_result = result[0], result[1]
    self.assertEqual(movie_result.past_user_inputs[0], movie_prompt)
    self.assertEqual(movie_result.generated_responses[0], "The Big Lebowski")
    self.assertEqual(book_result.past_user_inputs[0], book_prompt)
    self.assertEqual(book_result.generated_responses[0], "The Last Question")

    # When
    conversation_2.add_user_input(follow_up_prompt)
    result = nlp(conversation_2, do_sample=False, max_length=1000)
    # Then
    self.assertEqual(result, conversation_2)
    self.assertEqual(len(result.past_user_inputs), 2)
    self.assertEqual(len(result.generated_responses), 2)
    self.assertEqual(result.past_user_inputs[1], follow_up_prompt)
    self.assertEqual(result.generated_responses[1], "It's a good book.")
|
||||
|
||||
@slow
@require_torch
def test_integration_torch_conversation_truncated_history(self):
    """Run a two-turn conversation under a tight generation budget.

    The pipeline is built with ``min_length_for_response=24`` and every call
    uses ``max_length=36``.
    NOTE(review): presumably this budget forces the pipeline to truncate the
    history on the second turn -- confirm against the pipeline code.
    """
    opening_prompt = "Going to the movies tonight - any suggestions?"
    follow_up_prompt = "Is it an action movie?"

    # When
    nlp = pipeline(task="conversational", min_length_for_response=24, device=DEFAULT_DEVICE_NUM)
    conversation_1 = Conversation(opening_prompt)
    # Then
    self.assertEqual(len(conversation_1.past_user_inputs), 0)

    # When
    result = nlp(conversation_1, do_sample=False, max_length=36)
    # Then
    self.assertEqual(result, conversation_1)
    self.assertEqual(len(result.past_user_inputs), 1)
    self.assertEqual(len(result.generated_responses), 1)
    self.assertEqual(result.past_user_inputs[0], opening_prompt)
    self.assertEqual(result.generated_responses[0], "The Big Lebowski")

    # When
    conversation_1.add_user_input(follow_up_prompt)
    result = nlp(conversation_1, do_sample=False, max_length=36)
    # Then
    self.assertEqual(result, conversation_1)
    self.assertEqual(len(result.past_user_inputs), 2)
    self.assertEqual(len(result.generated_responses), 2)
    self.assertEqual(result.past_user_inputs[1], follow_up_prompt)
    self.assertEqual(result.generated_responses[1], "It's a comedy.")
|
||||
|
||||
|
||||
QA_FINETUNED_MODELS = ["sshleifer/tiny-distilbert-base-cased-distilled-squad"]
|
||||
|
||||
@@ -450,6 +510,38 @@ class ZeroShotClassificationPipelineTests(unittest.TestCase):
|
||||
self._test_zero_shot_pipeline_outputs(nlp)
|
||||
|
||||
|
||||
class DialoguePipelineTests(unittest.TestCase):
    """Checks of the ``conversational`` pipeline for each entry of DIALOGUE_FINETUNED_MODELS."""

    def _test_conversation_pipeline(self, nlp):
        """Check the return types of ``nlp`` and that unusable inputs raise.

        Args:
            nlp: the conversational pipeline under test.
        """
        single_conversation = Conversation("Hi there!")
        conversation_batch = [Conversation("Hi there!"), Conversation("How are you?")]
        rejected_inputs = ["Hi there!", Conversation()]
        self.assertIsNotNone(nlp)

        # One conversation in, one Conversation out.
        single_output = nlp(single_conversation)
        self.assertIsInstance(single_output, Conversation)

        # A list of conversations in, a list of Conversation objects out.
        batch_output = nlp(conversation_batch)
        self.assertIsInstance(batch_output, list)
        self.assertIsInstance(batch_output[0], Conversation)

        # Inactive conversations passed to the pipeline raise a ValueError
        # (the batch was already sent through the pipeline just above).
        with self.assertRaises(ValueError):
            nlp(conversation_batch)

        # Each rejected input must fail on its own, and so must the whole list.
        for rejected in rejected_inputs:
            with self.assertRaises(Exception):
                nlp(rejected)
        with self.assertRaises(Exception):
            nlp(rejected_inputs)

    @require_torch
    def test_torch_conversation(self):
        """Run the shared checks on every dialogue model with the default framework."""
        for checkpoint in DIALOGUE_FINETUNED_MODELS:
            conversational = pipeline(task="conversational", model=checkpoint, tokenizer=checkpoint)
            self._test_conversation_pipeline(conversational)

    @require_tf
    def test_tf_conversation(self):
        """Run the shared checks on every dialogue model with ``framework="tf"``."""
        for checkpoint in DIALOGUE_FINETUNED_MODELS:
            conversational = pipeline(
                task="conversational", model=checkpoint, tokenizer=checkpoint, framework="tf"
            )
            self._test_conversation_pipeline(conversational)
|
||||
|
||||
|
||||
class QAPipelineTests(unittest.TestCase):
|
||||
def _test_qa_pipeline(self, nlp):
|
||||
output_keys = {"score", "answer", "start", "end"}
|
||||
@@ -593,7 +685,6 @@ class NerPipelineTests(unittest.TestCase):
|
||||
|
||||
|
||||
class PipelineCommonTests(unittest.TestCase):
|
||||
|
||||
pipelines = SUPPORTED_TASKS.keys()
|
||||
|
||||
@slow
|
||||
|
||||
Reference in New Issue
Block a user