Making Conversation possible to create directly a full conversation (#9434)
* Cleaning up conversation tests. * Adding tests that don't require downloading models + conversation can be fully created from static state. * Making tests non flaky (by fixing generation length) * Bumping isort version. * Doc cleanup. * Remove unused test in this PR. * Torch import guard for TF. * Missing torch guard. * Small mistake in doc. * Actual uses `_history` and `_index` cache. + remove dead enumerate + improve warning message. * Update src/transformers/pipelines/conversational.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/pipelines/conversational.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/pipelines/conversational.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Adding comments and cleaner code to address history copy. * Improving pipeline name in tests. * Change tokenizer to a real one (still created at runtime with no external dependency) * Simplify DummyTok, reverse changes on tokenization. * Removing DummyTok. Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -33,6 +33,14 @@ class Conversation:
|
||||
conversation_id (:obj:`uuid.UUID`, `optional`):
|
||||
Unique identifier for the conversation. If not provided, a random UUID4 id will be assigned to the
|
||||
conversation.
|
||||
past_user_inputs (:obj:`List[str]`, `optional`):
|
||||
Optional past history of the user's side of the conversation. You don't need to pass it manually if you use the
|
||||
pipeline interactively but if you want to recreate history you need to set both :obj:`past_user_inputs` and
|
||||
:obj:`generated_responses` with equal length lists of strings.
|
||||
generated_responses (:obj:`List[str]`, `optional`):
|
||||
Optional past history of the model's side of the conversation. You don't need to pass it manually if you use the
|
||||
pipeline interactively but if you want to recreate history you need to set both :obj:`past_user_inputs` and
|
||||
:obj:`generated_responses` with equal length lists of strings.
|
||||
|
||||
Usage::
|
||||
|
||||
@@ -47,14 +55,33 @@ class Conversation:
|
||||
conversation.add_user_input("Is it good?")
|
||||
"""
|
||||
|
||||
def __init__(self, text: str = None, conversation_id: uuid.UUID = None):
|
||||
def __init__(
|
||||
self, text: str = None, conversation_id: uuid.UUID = None, past_user_inputs=None, generated_responses=None
|
||||
):
|
||||
if not conversation_id:
|
||||
conversation_id = uuid.uuid4()
|
||||
if past_user_inputs is None:
|
||||
past_user_inputs = []
|
||||
if generated_responses is None:
|
||||
generated_responses = []
|
||||
|
||||
self.uuid: uuid.UUID = conversation_id
|
||||
self.past_user_inputs: List[str] = []
|
||||
self.generated_responses: List[str] = []
|
||||
self.history: List[int] = []
|
||||
self.past_user_inputs: List[str] = past_user_inputs
|
||||
self.generated_responses: List[str] = generated_responses
|
||||
self.new_user_input: Optional[str] = text
|
||||
self._index: int = 0
|
||||
self._history: List[int] = []
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, Conversation):
|
||||
return False
|
||||
if self.uuid == other.uuid:
|
||||
return True
|
||||
return (
|
||||
self.new_user_input == other.new_user_input
|
||||
and self.past_user_inputs == other.past_user_inputs
|
||||
and self.generated_responses == other.generated_responses
|
||||
)
|
||||
|
||||
def add_user_input(self, text: str, overwrite: bool = False):
|
||||
"""
|
||||
@@ -100,16 +127,6 @@ class Conversation:
|
||||
"""
|
||||
self.generated_responses.append(response)
|
||||
|
||||
def set_history(self, history: List[int]):
|
||||
"""
|
||||
Updates the value of the history of the conversation. The history is represented by a list of :obj:`token_ids`.
|
||||
The history is used by the model to generate responses based on the previous conversation turns.
|
||||
|
||||
Args:
|
||||
history (:obj:`List[int]`): History of tokens provided and generated for this conversation.
|
||||
"""
|
||||
self.history = history
|
||||
|
||||
def __repr__(self):
|
||||
"""
|
||||
Generates a string representation of the conversation.
|
||||
@@ -167,12 +184,40 @@ class ConversationalPipeline(Pipeline):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# We need at least an eos_token
|
||||
assert self.tokenizer.eos_token_id is not None, "DialoguePipeline tokenizer should have an EOS token set"
|
||||
assert self.tokenizer.eos_token_id is not None, "ConversationalPipeline tokenizer should have an EOS token set"
|
||||
if self.tokenizer.pad_token_id is None:
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
|
||||
self.min_length_for_response = min_length_for_response
|
||||
|
||||
def _get_history(self, conversation):
|
||||
"""
|
||||
Private function (subject to change) that simply tokenizes and concatenates past inputs. Also saves that
|
||||
tokenization into the conversation state.
|
||||
|
||||
Args:
|
||||
conversation (:class:`~transformers.Conversation`)
|
||||
|
||||
Returns:
|
||||
:obj:`List[int]`: The list of tokens for the past input of that conversation.
|
||||
"""
|
||||
# Make a copy to prevent messing cache up if there's an error
|
||||
# within this function
|
||||
history = conversation._history.copy()
|
||||
index = conversation._index
|
||||
new_index = index
|
||||
for i, (past_user_input, generated_response) in enumerate(
|
||||
zip(conversation.past_user_inputs[index:], conversation.generated_responses[index:])
|
||||
):
|
||||
for el in (past_user_input, generated_response):
|
||||
new_history = self._parse_and_tokenize([el])[0]
|
||||
history.extend(new_history)
|
||||
new_index = i + index + 1
|
||||
conversation._index = new_index
|
||||
conversation._history = history
|
||||
# Hand back a copy to the caller so they can't accidentally modify our cache.
|
||||
return history.copy()
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
conversations: Union[Conversation, List[Conversation]],
|
||||
@@ -220,7 +265,7 @@ class ConversationalPipeline(Pipeline):
|
||||
with self.device_placement():
|
||||
|
||||
inputs = self._parse_and_tokenize([conversation.new_user_input for conversation in conversations])
|
||||
histories = [conversation.history for conversation in conversations]
|
||||
histories = [self._get_history(conversation) for conversation in conversations]
|
||||
max_length = generate_kwargs.get("max_length", self.model.config.max_length)
|
||||
inputs = self._concat_inputs_history(inputs, histories, max_length)
|
||||
|
||||
@@ -266,7 +311,6 @@ class ConversationalPipeline(Pipeline):
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
)
|
||||
)
|
||||
conversation.set_history(history[conversation_index])
|
||||
output.append(conversation)
|
||||
if len(output) == 1:
|
||||
return output[0]
|
||||
@@ -333,7 +377,9 @@ class ConversationalPipeline(Pipeline):
|
||||
if cutoff_eos_index == 0 or cutoff_eos_index == len(new_input) - 1:
|
||||
break
|
||||
else:
|
||||
new_input = new_input[cutoff_eos_index + 1 :]
|
||||
logger.warning(
|
||||
f"Cutting history off because it's too long ({len(new_input)} > {max_length - self.min_length_for_response}) for underlying model"
|
||||
)
|
||||
outputs.append(new_input)
|
||||
padded_outputs = self.tokenizer.pad(
|
||||
{"input_ids": outputs}, padding="longest", return_attention_mask=True, return_tensors=self.framework
|
||||
|
||||
Reference in New Issue
Block a user