Cleaning up ConversationalPipeline to support more than DialoGPT. (#10002)
* Cleaning up `ConversationalPipeline` to support more than DialoGPT. Currently ConversationalPipeline was heavily biased towards DialoGPT ,which is the default model for this pipeline. This PR proposes changes to put back the modifications specific to DialoGPT into tokenizer-specific behavior wherever possible, by creating `_build_conversation_input_ids` function that takes conversation as input, and returns a list of ints corresponding to the tokens. It feels natural to put here because all models have probably different strategies to build input_ids from the full conversation and it's the tokenizer's job to transform strings into tokens (and vice-versa) If `_build_conversation_input_ids` is missing, previous behavior is used so we don't break anything so far (except for blenderbot where it's a fix). This PR also contains a fix for too long inputs. There used to be dead code for trying to limit the size of incoming input. The introduced fixed is that we limit within `_build_conversation_input_ids` to `tokenizer.model_max_length`. It corresponds to the intent of the removed dead code and is actually better because it corresponds to `model_max_length` which is different from `max_length` (which is a default parameter for `generate`). - Removed `history` logic from the Conversation as it's not relevant anymore because tokenization logic has been moved to tokenizer. And tokenizer cannot save any cache, and conversation cannot know what is relevant or not. Also it's not usable from `blenderbot` because the input_ids are not append only (EOS tokens is always at the end). - Added `iter_texts` method on `Conversation` because all the code was literred with some form of this iteration of past/generated_responses. * Removing torch mention in types. * Adding type checking to `_build_conversation_input_ids`. * Fixing import in strings.
This commit is contained in:
@@ -14,12 +14,15 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Tokenization class for Blenderbot."""
|
"""Tokenization class for Blenderbot."""
|
||||||
|
|
||||||
from typing import List
|
from typing import TYPE_CHECKING, List
|
||||||
|
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ..roberta.tokenization_roberta import RobertaTokenizer
|
from ..roberta.tokenization_roberta import RobertaTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from transformers.pipelines.conversational import Conversation
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -74,6 +77,23 @@ class BlenderbotTokenizer(RobertaTokenizer):
|
|||||||
"""
|
"""
|
||||||
return token_ids_0 + [self.eos_token_id]
|
return token_ids_0 + [self.eos_token_id]
|
||||||
|
|
||||||
|
def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
|
||||||
|
inputs = []
|
||||||
|
for is_user, text in conversation.iter_texts():
|
||||||
|
if is_user:
|
||||||
|
# We need to space prefix as it's being done within blenderbot
|
||||||
|
inputs.append(" " + text)
|
||||||
|
else:
|
||||||
|
# Generated responses should contain them already.
|
||||||
|
inputs.append(text)
|
||||||
|
|
||||||
|
full_string = " ".join(inputs)
|
||||||
|
input_ids = self.encode(full_string)
|
||||||
|
if len(input_ids) > self.model_max_length:
|
||||||
|
input_ids = input_ids[-self.model_max_length :]
|
||||||
|
logger.warning(f"Trimmed input from conversation as it was longer than {self.model_max_length} tokens.")
|
||||||
|
return input_ids
|
||||||
|
|
||||||
|
|
||||||
def get_pairs(word):
|
def get_pairs(word):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -18,7 +18,7 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Optional, Tuple
|
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||||
|
|
||||||
import regex as re
|
import regex as re
|
||||||
|
|
||||||
@@ -26,6 +26,9 @@ from ...tokenization_utils import AddedToken, PreTrainedTokenizer
|
|||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from transformers.pipelines.conversational import Conversation
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
VOCAB_FILES_NAMES = {
|
VOCAB_FILES_NAMES = {
|
||||||
@@ -296,3 +299,11 @@ class GPT2Tokenizer(PreTrainedTokenizer):
|
|||||||
if is_split_into_words or add_prefix_space:
|
if is_split_into_words or add_prefix_space:
|
||||||
text = " " + text
|
text = " " + text
|
||||||
return (text, kwargs)
|
return (text, kwargs)
|
||||||
|
|
||||||
|
def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
|
||||||
|
input_ids = []
|
||||||
|
for is_user, text in conversation.iter_texts():
|
||||||
|
input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
|
||||||
|
if len(input_ids) > self.model_max_length:
|
||||||
|
input_ids = input_ids[-self.model_max_length :]
|
||||||
|
return input_ids
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import Optional, Tuple
|
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||||
|
|
||||||
from tokenizers import pre_tokenizers
|
from tokenizers import pre_tokenizers
|
||||||
|
|
||||||
@@ -26,6 +26,10 @@ from ...utils import logging
|
|||||||
from .tokenization_gpt2 import GPT2Tokenizer
|
from .tokenization_gpt2 import GPT2Tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from transformers.pipelines.conversational import Conversation
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
|
VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
|
||||||
@@ -171,3 +175,13 @@ class GPT2TokenizerFast(PreTrainedTokenizerFast):
|
|||||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||||
files = self._tokenizer.model.save(save_directory, name=filename_prefix)
|
files = self._tokenizer.model.save(save_directory, name=filename_prefix)
|
||||||
return tuple(files)
|
return tuple(files)
|
||||||
|
|
||||||
|
def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
|
||||||
|
"""This corresponds to DialoGPT variants of models."""
|
||||||
|
input_ids = []
|
||||||
|
for is_user, text in conversation.iter_texts():
|
||||||
|
input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
|
||||||
|
|
||||||
|
if len(input_ids) > self.model_max_length:
|
||||||
|
input_ids = input_ids[-self.model_max_length :]
|
||||||
|
return input_ids
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
|
from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
|
||||||
from ..tokenization_utils import TruncationStrategy
|
|
||||||
from ..utils import logging
|
from ..utils import logging
|
||||||
from .base import PIPELINE_INIT_ARGS, Pipeline
|
from .base import PIPELINE_INIT_ARGS, Pipeline
|
||||||
|
|
||||||
@@ -70,8 +69,6 @@ class Conversation:
|
|||||||
self.past_user_inputs: List[str] = past_user_inputs
|
self.past_user_inputs: List[str] = past_user_inputs
|
||||||
self.generated_responses: List[str] = generated_responses
|
self.generated_responses: List[str] = generated_responses
|
||||||
self.new_user_input: Optional[str] = text
|
self.new_user_input: Optional[str] = text
|
||||||
self._index: int = 0
|
|
||||||
self._history: List[int] = []
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if not isinstance(other, Conversation):
|
if not isinstance(other, Conversation):
|
||||||
@@ -128,6 +125,19 @@ class Conversation:
|
|||||||
"""
|
"""
|
||||||
self.generated_responses.append(response)
|
self.generated_responses.append(response)
|
||||||
|
|
||||||
|
def iter_texts(self):
|
||||||
|
"""
|
||||||
|
Iterates over all blobs of the conversation.
|
||||||
|
|
||||||
|
Retuns: Iterator of (is_user, text_chunk) in chronological order of the conversation. ``is_user`` is a
|
||||||
|
:obj:`bool`, ``text_chunks`` is a :obj:`str`.
|
||||||
|
"""
|
||||||
|
for user_input, generated_response in zip(self.past_user_inputs, self.generated_responses):
|
||||||
|
yield True, user_input
|
||||||
|
yield False, generated_response
|
||||||
|
if self.new_user_input:
|
||||||
|
yield True, self.new_user_input
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
"""
|
"""
|
||||||
Generates a string representation of the conversation.
|
Generates a string representation of the conversation.
|
||||||
@@ -139,11 +149,9 @@ class Conversation:
|
|||||||
suggestions? bot >> The Big Lebowski
|
suggestions? bot >> The Big Lebowski
|
||||||
"""
|
"""
|
||||||
output = "Conversation id: {} \n".format(self.uuid)
|
output = "Conversation id: {} \n".format(self.uuid)
|
||||||
for user_input, generated_response in zip(self.past_user_inputs, self.generated_responses):
|
for is_user, text in self.iter_texts():
|
||||||
output += "user >> {} \n".format(user_input)
|
name = "user" if is_user else "bot"
|
||||||
output += "bot >> {} \n".format(generated_response)
|
output += "{} >> {} \n".format(name, text)
|
||||||
if self.new_user_input is not None:
|
|
||||||
output += "user >> {} \n".format(self.new_user_input)
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
@@ -191,34 +199,6 @@ class ConversationalPipeline(Pipeline):
|
|||||||
|
|
||||||
self.min_length_for_response = min_length_for_response
|
self.min_length_for_response = min_length_for_response
|
||||||
|
|
||||||
def _get_history(self, conversation):
|
|
||||||
"""
|
|
||||||
Private function (subject to change) that simply tokenizes and concatenates past inputs. Also saves that
|
|
||||||
tokenization into the conversation state.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
conversation (:class:`~transformers.Conversation`)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
:obj:`List[int]`: The list of tokens for the past input of that conversation.
|
|
||||||
"""
|
|
||||||
# Make a copy to prevent messing cache up if there's an error
|
|
||||||
# within this function
|
|
||||||
history = conversation._history.copy()
|
|
||||||
index = conversation._index
|
|
||||||
new_index = index
|
|
||||||
for i, (past_user_input, generated_response) in enumerate(
|
|
||||||
zip(conversation.past_user_inputs[index:], conversation.generated_responses[index:])
|
|
||||||
):
|
|
||||||
for el in (past_user_input, generated_response):
|
|
||||||
new_history = self._parse_and_tokenize([el])[0]
|
|
||||||
history.extend(new_history)
|
|
||||||
new_index = i + index + 1
|
|
||||||
conversation._index = new_index
|
|
||||||
conversation._history = history
|
|
||||||
# Hand back a copy to caller so they can't accidently modify our cache.
|
|
||||||
return history.copy()
|
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
conversations: Union[Conversation, List[Conversation]],
|
conversations: Union[Conversation, List[Conversation]],
|
||||||
@@ -249,7 +229,7 @@ class ConversationalPipeline(Pipeline):
|
|||||||
for conversation in conversations:
|
for conversation in conversations:
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
conversation, Conversation
|
conversation, Conversation
|
||||||
), "DialoguePipeline expects a Conversation or list of Conversations as an input"
|
), "ConversationalPipeline expects a Conversation or list of Conversations as an input"
|
||||||
if conversation.new_user_input is None:
|
if conversation.new_user_input is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Conversation with UUID {} does not contain new user input to process. "
|
"Conversation with UUID {} does not contain new user input to process. "
|
||||||
@@ -261,14 +241,11 @@ class ConversationalPipeline(Pipeline):
|
|||||||
self.tokenizer.pad_token_id is not None or self.tokenizer.eos_token_id is not None
|
self.tokenizer.pad_token_id is not None or self.tokenizer.eos_token_id is not None
|
||||||
), "Please make sure that the tokenizer has a pad_token_id or eos_token_id when using a batch input"
|
), "Please make sure that the tokenizer has a pad_token_id or eos_token_id when using a batch input"
|
||||||
else:
|
else:
|
||||||
raise ValueError("DialoguePipeline expects a Conversation or list of Conversations as an input")
|
raise ValueError("ConversationalPipeline expects a Conversation or list of Conversations as an input")
|
||||||
|
|
||||||
with self.device_placement():
|
with self.device_placement():
|
||||||
|
|
||||||
inputs = self._parse_and_tokenize([conversation.new_user_input for conversation in conversations])
|
inputs = self._parse_and_tokenize(conversations)
|
||||||
histories = [self._get_history(conversation) for conversation in conversations]
|
|
||||||
max_length = generate_kwargs.get("max_length", self.model.config.max_length)
|
|
||||||
inputs = self._concat_inputs_history(inputs, histories, max_length)
|
|
||||||
|
|
||||||
if self.framework == "pt":
|
if self.framework == "pt":
|
||||||
inputs = self.ensure_tensor_on_device(**inputs)
|
inputs = self.ensure_tensor_on_device(**inputs)
|
||||||
@@ -277,11 +254,6 @@ class ConversationalPipeline(Pipeline):
|
|||||||
elif self.framework == "tf":
|
elif self.framework == "tf":
|
||||||
input_length = tf.shape(inputs["input_ids"])[-1].numpy()
|
input_length = tf.shape(inputs["input_ids"])[-1].numpy()
|
||||||
|
|
||||||
if input_length > 0.9 * max_length:
|
|
||||||
logger.warning(
|
|
||||||
"Longest conversation length: {} is bigger than 0.9 * max_length: {}. "
|
|
||||||
"You might consider trimming the early phase of the conversation".format(input_length, max_length)
|
|
||||||
)
|
|
||||||
generated_responses = self.model.generate(
|
generated_responses = self.model.generate(
|
||||||
inputs["input_ids"],
|
inputs["input_ids"],
|
||||||
attention_mask=inputs["attention_mask"],
|
attention_mask=inputs["attention_mask"],
|
||||||
@@ -318,18 +290,6 @@ class ConversationalPipeline(Pipeline):
|
|||||||
else:
|
else:
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def _parse_and_tokenize(
|
|
||||||
self, inputs, add_special_tokens=False, padding=False, truncation=TruncationStrategy.DO_NOT_TRUNCATE, **kwargs
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Parse arguments and tokenize, adding an EOS token at the end of the user input
|
|
||||||
"""
|
|
||||||
# Parse arguments
|
|
||||||
inputs = self.tokenizer(inputs, add_special_tokens=add_special_tokens, padding=padding).get("input_ids", [])
|
|
||||||
for input in inputs:
|
|
||||||
input.append(self.tokenizer.eos_token_id)
|
|
||||||
return inputs
|
|
||||||
|
|
||||||
def _clean_padding_history(self, generated_tensor) -> List[List[int]]:
|
def _clean_padding_history(self, generated_tensor) -> List[List[int]]:
|
||||||
"""
|
"""
|
||||||
Cleans the padding history. Padding may be generated in two places when multiple conversations are provided as
|
Cleans the padding history. Padding may be generated in two places when multiple conversations are provided as
|
||||||
@@ -363,28 +323,23 @@ class ConversationalPipeline(Pipeline):
|
|||||||
outputs.append(sequence_tokens)
|
outputs.append(sequence_tokens)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def _concat_inputs_history(self, inputs: List[List[int]], histories: List[Optional[List[int]]], max_length: int):
|
def _legacy_parse_and_tokenize(self, conversation: List[Conversation]) -> List[int]:
|
||||||
"""
|
eos_token_id = self.tokenizer.eos_token_id
|
||||||
Builds an input prepended by the history for this conversation, allowing multi-turn conversation with context
|
input_ids = []
|
||||||
"""
|
for is_user, text in conversation.iter_texts():
|
||||||
outputs = []
|
input_ids.extend(self.tokenizer.encode(text, add_special_tokens=False) + [eos_token_id])
|
||||||
for new_input, history in zip(inputs, histories):
|
|
||||||
if history is not None:
|
if len(input_ids) > self.tokenizer.model_max_length:
|
||||||
new_input = history + new_input
|
input_ids = input_ids[-self.model_max_length :]
|
||||||
if len(new_input) > max_length - self.min_length_for_response:
|
return input_ids
|
||||||
cutoff_eos_index = 0
|
|
||||||
while len(new_input) - cutoff_eos_index > max_length - self.min_length_for_response:
|
def _parse_and_tokenize(self, conversations: List[Conversation]) -> Dict[str, Any]:
|
||||||
if cutoff_eos_index >= len(new_input):
|
if hasattr(self.tokenizer, "_build_conversation_input_ids"):
|
||||||
break
|
input_ids = [self.tokenizer._build_conversation_input_ids(conversation) for conversation in conversations]
|
||||||
cutoff_eos_index = new_input[cutoff_eos_index:].index(self.tokenizer.eos_token_id)
|
else:
|
||||||
if cutoff_eos_index == 0 or cutoff_eos_index == len(new_input) - 1:
|
# If the tokenizer cannot handle conversations, we default to only the old version
|
||||||
break
|
input_ids = [self._legacy_parse_and_tokenize(conversation) for conversation in conversations]
|
||||||
else:
|
inputs = self.tokenizer.pad(
|
||||||
logger.warning(
|
{"input_ids": input_ids}, padding="longest", return_attention_mask=True, return_tensors="pt"
|
||||||
f"Cutting history off because it's too long ({len(new_input)} > {max_length - self.min_length_for_response}) for underlying model"
|
|
||||||
)
|
|
||||||
outputs.append(new_input)
|
|
||||||
padded_outputs = self.tokenizer.pad(
|
|
||||||
{"input_ids": outputs}, padding="longest", return_attention_mask=True, return_tensors=self.framework
|
|
||||||
)
|
)
|
||||||
return padded_outputs
|
return inputs
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
|
AutoModelForCausalLM,
|
||||||
AutoModelForSeq2SeqLM,
|
AutoModelForSeq2SeqLM,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
Conversation,
|
Conversation,
|
||||||
@@ -87,11 +88,7 @@ class SimpleConversationPipelineTests(unittest.TestCase):
|
|||||||
self.assertEqual(len(conversation_1.past_user_inputs), 0)
|
self.assertEqual(len(conversation_1.past_user_inputs), 0)
|
||||||
self.assertEqual(len(conversation_2.past_user_inputs), 0)
|
self.assertEqual(len(conversation_2.past_user_inputs), 0)
|
||||||
|
|
||||||
with self.assertLogs("transformers", level="WARNING") as log:
|
result = conversation_agent([conversation_1, conversation_2], max_length=48)
|
||||||
result = conversation_agent([conversation_1, conversation_2], max_length=48)
|
|
||||||
self.assertEqual(len(log.output), 2)
|
|
||||||
self.assertIn("You might consider trimming the early phase of the conversation", log.output[0])
|
|
||||||
self.assertIn("Setting `pad_token_id`", log.output[1])
|
|
||||||
|
|
||||||
# Two conversations in one pass
|
# Two conversations in one pass
|
||||||
self.assertEqual(result, [conversation_1, conversation_2])
|
self.assertEqual(result, [conversation_1, conversation_2])
|
||||||
@@ -111,12 +108,7 @@ class SimpleConversationPipelineTests(unittest.TestCase):
|
|||||||
|
|
||||||
# One conversation with history
|
# One conversation with history
|
||||||
conversation_2.add_user_input("Why do you recommend it?")
|
conversation_2.add_user_input("Why do you recommend it?")
|
||||||
with self.assertLogs("transformers", level="WARNING") as log:
|
result = conversation_agent(conversation_2, max_length=64)
|
||||||
result = conversation_agent(conversation_2, max_length=64)
|
|
||||||
self.assertEqual(len(log.output), 3)
|
|
||||||
self.assertIn("Cutting history off because it's too long", log.output[0])
|
|
||||||
self.assertIn("You might consider trimming the early phase of the conversation", log.output[1])
|
|
||||||
self.assertIn("Setting `pad_token_id`", log.output[2])
|
|
||||||
|
|
||||||
self.assertEqual(result, conversation_2)
|
self.assertEqual(result, conversation_2)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
@@ -128,65 +120,6 @@ class SimpleConversationPipelineTests(unittest.TestCase):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@require_torch
|
|
||||||
def test_history_cache(self):
|
|
||||||
conversation_agent = self.get_pipeline()
|
|
||||||
conversation = Conversation(
|
|
||||||
"Why do you recommend it?",
|
|
||||||
past_user_inputs=["What's the last book you have read?"],
|
|
||||||
generated_responses=["b"],
|
|
||||||
)
|
|
||||||
with self.assertLogs("transformers", level="WARNING") as log:
|
|
||||||
_ = conversation_agent(conversation, max_length=64)
|
|
||||||
self.assertEqual(len(log.output), 3)
|
|
||||||
self.assertIn("Cutting history off because it's too long (63 > 32) for underlying model", log.output[0])
|
|
||||||
self.assertIn("63 is bigger than 0.9 * max_length: 64", log.output[1])
|
|
||||||
self.assertIn("Setting `pad_token_id`", log.output[2])
|
|
||||||
self.assertEqual(conversation._index, 1)
|
|
||||||
self.assertEqual(
|
|
||||||
conversation._history,
|
|
||||||
[
|
|
||||||
87,
|
|
||||||
104,
|
|
||||||
97,
|
|
||||||
116,
|
|
||||||
39,
|
|
||||||
115,
|
|
||||||
32,
|
|
||||||
116,
|
|
||||||
104,
|
|
||||||
101,
|
|
||||||
32,
|
|
||||||
108,
|
|
||||||
97,
|
|
||||||
115,
|
|
||||||
116,
|
|
||||||
32,
|
|
||||||
98,
|
|
||||||
111,
|
|
||||||
111,
|
|
||||||
107,
|
|
||||||
32,
|
|
||||||
121,
|
|
||||||
111,
|
|
||||||
117,
|
|
||||||
32,
|
|
||||||
104,
|
|
||||||
97,
|
|
||||||
118,
|
|
||||||
101,
|
|
||||||
32,
|
|
||||||
114,
|
|
||||||
101,
|
|
||||||
97,
|
|
||||||
100,
|
|
||||||
63,
|
|
||||||
259, # EOS
|
|
||||||
98, # b
|
|
||||||
259, # EOS
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||||
pipeline_task = "conversational"
|
pipeline_task = "conversational"
|
||||||
@@ -276,6 +209,102 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
|
|||||||
self.assertEqual(result.past_user_inputs[1], "Is it an action movie?")
|
self.assertEqual(result.past_user_inputs[1], "Is it an action movie?")
|
||||||
self.assertEqual(result.generated_responses[1], "It's a comedy.")
|
self.assertEqual(result.generated_responses[1], "It's a comedy.")
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@slow
|
||||||
|
def test_integration_torch_conversation_dialogpt_input_ids(self):
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
|
||||||
|
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
|
||||||
|
nlp = ConversationalPipeline(model=model, tokenizer=tokenizer)
|
||||||
|
|
||||||
|
conversation_1 = Conversation("hello")
|
||||||
|
inputs = nlp._parse_and_tokenize([conversation_1])
|
||||||
|
self.assertEqual(inputs["input_ids"].tolist(), [[31373, 50256]])
|
||||||
|
|
||||||
|
conversation_2 = Conversation("how are you ?", past_user_inputs=["hello"], generated_responses=["Hi there!"])
|
||||||
|
inputs = nlp._parse_and_tokenize([conversation_2])
|
||||||
|
self.assertEqual(
|
||||||
|
inputs["input_ids"].tolist(), [[31373, 50256, 17250, 612, 0, 50256, 4919, 389, 345, 5633, 50256]]
|
||||||
|
)
|
||||||
|
|
||||||
|
inputs = nlp._parse_and_tokenize([conversation_1, conversation_2])
|
||||||
|
self.assertEqual(
|
||||||
|
inputs["input_ids"].tolist(),
|
||||||
|
[
|
||||||
|
[31373, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256],
|
||||||
|
[31373, 50256, 17250, 612, 0, 50256, 4919, 389, 345, 5633, 50256],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@slow
|
||||||
|
def test_integration_torch_conversation_blenderbot_400M_input_ids(self):
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
|
||||||
|
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/blenderbot-400M-distill")
|
||||||
|
nlp = ConversationalPipeline(model=model, tokenizer=tokenizer)
|
||||||
|
|
||||||
|
# test1
|
||||||
|
conversation_1 = Conversation("hello")
|
||||||
|
inputs = nlp._parse_and_tokenize([conversation_1])
|
||||||
|
self.assertEqual(inputs["input_ids"].tolist(), [[1710, 86, 2]])
|
||||||
|
|
||||||
|
# test2
|
||||||
|
conversation_1 = Conversation(
|
||||||
|
"I like lasagne.",
|
||||||
|
past_user_inputs=["hello"],
|
||||||
|
generated_responses=[
|
||||||
|
" Do you like lasagne? It is a traditional Italian dish consisting of a shepherd's pie."
|
||||||
|
],
|
||||||
|
)
|
||||||
|
inputs = nlp._parse_and_tokenize([conversation_1])
|
||||||
|
self.assertEqual(
|
||||||
|
inputs["input_ids"].tolist(),
|
||||||
|
[
|
||||||
|
# This should be compared with the same conversation on ParlAI `safe_interactive` demo.
|
||||||
|
[
|
||||||
|
1710, # hello
|
||||||
|
86,
|
||||||
|
228, # Double space
|
||||||
|
228,
|
||||||
|
946,
|
||||||
|
304,
|
||||||
|
398,
|
||||||
|
6881,
|
||||||
|
558,
|
||||||
|
964,
|
||||||
|
38,
|
||||||
|
452,
|
||||||
|
315,
|
||||||
|
265,
|
||||||
|
6252,
|
||||||
|
452,
|
||||||
|
322,
|
||||||
|
968,
|
||||||
|
6884,
|
||||||
|
3146,
|
||||||
|
278,
|
||||||
|
306,
|
||||||
|
265,
|
||||||
|
617,
|
||||||
|
87,
|
||||||
|
388,
|
||||||
|
75,
|
||||||
|
341,
|
||||||
|
286,
|
||||||
|
521,
|
||||||
|
21,
|
||||||
|
228, # Double space
|
||||||
|
228,
|
||||||
|
281, # I like lasagne.
|
||||||
|
398,
|
||||||
|
6881,
|
||||||
|
558,
|
||||||
|
964,
|
||||||
|
21,
|
||||||
|
2, # EOS
|
||||||
|
]
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@slow
|
@slow
|
||||||
def test_integration_torch_conversation_blenderbot_400M(self):
|
def test_integration_torch_conversation_blenderbot_400M(self):
|
||||||
@@ -295,11 +324,11 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
|
|||||||
" Hello! How are you doing today? I just got back from a walk with my dog.",
|
" Hello! How are you doing today? I just got back from a walk with my dog.",
|
||||||
)
|
)
|
||||||
|
|
||||||
conversation_1 = Conversation(" Lasagne hello")
|
conversation_1 = Conversation("Lasagne hello")
|
||||||
result = nlp(conversation_1, encoder_no_repeat_ngram_size=3)
|
result = nlp(conversation_1, encoder_no_repeat_ngram_size=3)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
result.generated_responses[0],
|
result.generated_responses[0],
|
||||||
" Lasagne is my favorite Italian dish. Do you like lasagne?",
|
" Do you like lasagne? It is a traditional Italian dish consisting of a shepherd's pie.",
|
||||||
)
|
)
|
||||||
|
|
||||||
conversation_1 = Conversation(
|
conversation_1 = Conversation(
|
||||||
@@ -311,10 +340,7 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
|
|||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
result.generated_responses[0],
|
result.generated_responses[0],
|
||||||
# ParlAI implementation output, we have a different one, but it's our
|
" Me too. I like how it can be topped with vegetables, meats, and condiments.",
|
||||||
# second best, you can check by using num_return_sequences=10
|
|
||||||
# " Hello! How are you? I'm just getting ready to go to work, how about you?",
|
|
||||||
" Lasagne is a traditional Italian dish consisting of a yeasted flatbread typically topped with tomato sauce and cheese.",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
|
|||||||
Reference in New Issue
Block a user