Cleaning up ConversationalPipeline to support more than DialoGPT. (#10002)
* Cleaning up `ConversationalPipeline` to support more than DialoGPT. Currently ConversationalPipeline was heavily biased towards DialoGPT ,which is the default model for this pipeline. This PR proposes changes to put back the modifications specific to DialoGPT into tokenizer-specific behavior wherever possible, by creating `_build_conversation_input_ids` function that takes conversation as input, and returns a list of ints corresponding to the tokens. It feels natural to put here because all models have probably different strategies to build input_ids from the full conversation and it's the tokenizer's job to transform strings into tokens (and vice-versa) If `_build_conversation_input_ids` is missing, previous behavior is used so we don't break anything so far (except for blenderbot where it's a fix). This PR also contains a fix for too long inputs. There used to be dead code for trying to limit the size of incoming input. The introduced fixed is that we limit within `_build_conversation_input_ids` to `tokenizer.model_max_length`. It corresponds to the intent of the removed dead code and is actually better because it corresponds to `model_max_length` which is different from `max_length` (which is a default parameter for `generate`). - Removed `history` logic from the Conversation as it's not relevant anymore because tokenization logic has been moved to tokenizer. And tokenizer cannot save any cache, and conversation cannot know what is relevant or not. Also it's not usable from `blenderbot` because the input_ids are not append only (EOS tokens is always at the end). - Added `iter_texts` method on `Conversation` because all the code was literred with some form of this iteration of past/generated_responses. * Removing torch mention in types. * Adding type checking to `_build_conversation_input_ids`. * Fixing import in strings.
This commit is contained in:
@@ -14,12 +14,15 @@
|
||||
# limitations under the License.
|
||||
"""Tokenization class for Blenderbot."""
|
||||
|
||||
from typing import List
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
from ...utils import logging
|
||||
from ..roberta.tokenization_roberta import RobertaTokenizer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.pipelines.conversational import Conversation
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@@ -74,6 +77,23 @@ class BlenderbotTokenizer(RobertaTokenizer):
|
||||
"""
|
||||
return token_ids_0 + [self.eos_token_id]
|
||||
|
||||
def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
|
||||
inputs = []
|
||||
for is_user, text in conversation.iter_texts():
|
||||
if is_user:
|
||||
# We need to space prefix as it's being done within blenderbot
|
||||
inputs.append(" " + text)
|
||||
else:
|
||||
# Generated responses should contain them already.
|
||||
inputs.append(text)
|
||||
|
||||
full_string = " ".join(inputs)
|
||||
input_ids = self.encode(full_string)
|
||||
if len(input_ids) > self.model_max_length:
|
||||
input_ids = input_ids[-self.model_max_length :]
|
||||
logger.warning(f"Trimmed input from conversation as it was longer than {self.model_max_length} tokens.")
|
||||
return input_ids
|
||||
|
||||
|
||||
def get_pairs(word):
|
||||
"""
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
import json
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import Optional, Tuple
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
|
||||
import regex as re
|
||||
|
||||
@@ -26,6 +26,9 @@ from ...tokenization_utils import AddedToken, PreTrainedTokenizer
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.pipelines.conversational import Conversation
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
VOCAB_FILES_NAMES = {
|
||||
@@ -296,3 +299,11 @@ class GPT2Tokenizer(PreTrainedTokenizer):
|
||||
if is_split_into_words or add_prefix_space:
|
||||
text = " " + text
|
||||
return (text, kwargs)
|
||||
|
||||
def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
|
||||
input_ids = []
|
||||
for is_user, text in conversation.iter_texts():
|
||||
input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
|
||||
if len(input_ids) > self.model_max_length:
|
||||
input_ids = input_ids[-self.model_max_length :]
|
||||
return input_ids
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
|
||||
import json
|
||||
from typing import Optional, Tuple
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
|
||||
from tokenizers import pre_tokenizers
|
||||
|
||||
@@ -26,6 +26,10 @@ from ...utils import logging
|
||||
from .tokenization_gpt2 import GPT2Tokenizer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.pipelines.conversational import Conversation
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
|
||||
@@ -171,3 +175,13 @@ class GPT2TokenizerFast(PreTrainedTokenizerFast):
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
files = self._tokenizer.model.save(save_directory, name=filename_prefix)
|
||||
return tuple(files)
|
||||
|
||||
def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
|
||||
"""This corresponds to DialoGPT variants of models."""
|
||||
input_ids = []
|
||||
for is_user, text in conversation.iter_texts():
|
||||
input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
|
||||
|
||||
if len(input_ids) > self.model_max_length:
|
||||
input_ids = input_ids[-self.model_max_length :]
|
||||
return input_ids
|
||||
|
||||
Reference in New Issue
Block a user