Conversation pipeline fixes (#26795)
* Adjust length limits and allow naked conversation list inputs * Adjust length limits and allow naked conversation list inputs * Maybe use a slightly more reasonable limit than 1024 * Skip tests for old models that never supported this anyway * Cleanup input docstrings * More docstring cleanup + skip failing TF test * Make fixup
This commit is contained in:
@@ -247,13 +247,15 @@ class ConversationalPipeline(Pipeline):
|
|||||||
forward_params.update(generate_kwargs)
|
forward_params.update(generate_kwargs)
|
||||||
return preprocess_params, forward_params, postprocess_params
|
return preprocess_params, forward_params, postprocess_params
|
||||||
|
|
||||||
def __call__(self, conversations: Union[Conversation, List[Conversation]], num_workers=0, **kwargs):
|
def __call__(self, conversations: Union[List[Dict], Conversation, List[Conversation]], num_workers=0, **kwargs):
|
||||||
r"""
|
r"""
|
||||||
Generate responses for the conversation(s) given as inputs.
|
Generate responses for the conversation(s) given as inputs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
conversations (a [`Conversation`] or a list of [`Conversation`]):
|
conversations (a [`Conversation`] or a list of [`Conversation`]):
|
||||||
Conversations to generate responses for.
|
Conversation to generate responses for. Inputs can also be passed as a list of dictionaries with `role`
|
||||||
|
and `content` keys - in this case, they will be converted to `Conversation` objects automatically.
|
||||||
|
Multiple conversations in either format may be passed as a list.
|
||||||
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
|
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
|
||||||
Whether or not to clean up the potential extra spaces in the text output.
|
Whether or not to clean up the potential extra spaces in the text output.
|
||||||
generate_kwargs:
|
generate_kwargs:
|
||||||
@@ -268,6 +270,10 @@ class ConversationalPipeline(Pipeline):
|
|||||||
# Otherwise the threads will require a Conversation copy.
|
# Otherwise the threads will require a Conversation copy.
|
||||||
# This will definitely hinder performance on GPU, but has to be opted
|
# This will definitely hinder performance on GPU, but has to be opted
|
||||||
# in because of this BC change.
|
# in because of this BC change.
|
||||||
|
if isinstance(conversations, list) and isinstance(conversations[0], dict):
|
||||||
|
conversations = Conversation(conversations)
|
||||||
|
elif isinstance(conversations, list) and isinstance(conversations[0], list):
|
||||||
|
conversations = [Conversation(conv) for conv in conversations]
|
||||||
outputs = super().__call__(conversations, num_workers=num_workers, **kwargs)
|
outputs = super().__call__(conversations, num_workers=num_workers, **kwargs)
|
||||||
if isinstance(outputs, list) and len(outputs) == 1:
|
if isinstance(outputs, list) and len(outputs) == 1:
|
||||||
return outputs[0]
|
return outputs[0]
|
||||||
@@ -283,19 +289,10 @@ class ConversationalPipeline(Pipeline):
|
|||||||
return {"input_ids": input_ids, "conversation": conversation}
|
return {"input_ids": input_ids, "conversation": conversation}
|
||||||
|
|
||||||
def _forward(self, model_inputs, minimum_tokens=10, **generate_kwargs):
|
def _forward(self, model_inputs, minimum_tokens=10, **generate_kwargs):
|
||||||
max_length = generate_kwargs.get("max_length", self.model.config.max_length)
|
|
||||||
|
|
||||||
n = model_inputs["input_ids"].shape[1]
|
n = model_inputs["input_ids"].shape[1]
|
||||||
if max_length - minimum_tokens < n:
|
|
||||||
logger.warning(
|
|
||||||
f"Conversation input is too long ({n}), trimming it to {max_length - minimum_tokens} tokens. Consider increasing `max_length` to avoid truncation."
|
|
||||||
)
|
|
||||||
trim = max_length - minimum_tokens
|
|
||||||
model_inputs["input_ids"] = model_inputs["input_ids"][:, -trim:]
|
|
||||||
if "attention_mask" in model_inputs:
|
|
||||||
model_inputs["attention_mask"] = model_inputs["attention_mask"][:, -trim:]
|
|
||||||
conversation = model_inputs.pop("conversation")
|
conversation = model_inputs.pop("conversation")
|
||||||
generate_kwargs["max_length"] = max_length
|
if "max_length" not in generate_kwargs and "max_new_tokens" not in generate_kwargs:
|
||||||
|
generate_kwargs["max_new_tokens"] = 256
|
||||||
output_ids = self.model.generate(**model_inputs, **generate_kwargs)
|
output_ids = self.model.generate(**model_inputs, **generate_kwargs)
|
||||||
if self.model.config.is_encoder_decoder:
|
if self.model.config.is_encoder_decoder:
|
||||||
start_position = 1
|
start_position = 1
|
||||||
|
|||||||
@@ -507,6 +507,10 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
|||||||
model.generate(input_ids, attention_mask=attention_mask)
|
model.generate(input_ids, attention_mask=attention_mask)
|
||||||
model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
|
model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
|
||||||
|
|
||||||
|
@unittest.skip("Does not support conversations.")
|
||||||
|
def test_pipeline_conversational(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def assert_tensors_close(a, b, atol=1e-12, prefix=""):
|
def assert_tensors_close(a, b, atol=1e-12, prefix=""):
|
||||||
"""If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
|
"""If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
|
||||||
|
|||||||
@@ -343,6 +343,10 @@ class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, PipelineTester
|
|||||||
# check that the output for the restored model is the same
|
# check that the output for the restored model is the same
|
||||||
self.assert_outputs_same(restored_model_outputs, outputs)
|
self.assert_outputs_same(restored_model_outputs, outputs)
|
||||||
|
|
||||||
|
@unittest.skip("Does not support conversations.")
|
||||||
|
def test_pipeline_conversational(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _long_tensor(tok_lst):
|
def _long_tensor(tok_lst):
|
||||||
return tf.constant(tok_lst, dtype=tf.int32)
|
return tf.constant(tok_lst, dtype=tf.int32)
|
||||||
|
|||||||
@@ -891,6 +891,10 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
def test_disk_offload(self):
|
def test_disk_offload(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Does not support conversations.")
|
||||||
|
def test_pipeline_conversational(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class T5EncoderOnlyModelTester:
|
class T5EncoderOnlyModelTester:
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@@ -314,6 +314,10 @@ class TFT5ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
def test_keras_save_load(self):
|
def test_keras_save_load(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Does not support conversations.")
|
||||||
|
def test_pipeline_conversational(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TFT5EncoderOnlyModelTester:
|
class TFT5EncoderOnlyModelTester:
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -607,6 +611,10 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
|
|||||||
expected_output_string = ["Ich liebe es so sehr!", "die Transformatoren sind wirklich erstaunlich"]
|
expected_output_string = ["Ich liebe es so sehr!", "die Transformatoren sind wirklich erstaunlich"]
|
||||||
self.assertListEqual(expected_output_string, output_strings)
|
self.assertListEqual(expected_output_string, output_strings)
|
||||||
|
|
||||||
|
@unittest.skip("Does not support conversations.")
|
||||||
|
def test_pipeline_conversational(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
@require_sentencepiece
|
@require_sentencepiece
|
||||||
|
|||||||
Reference in New Issue
Block a user