Conversation pipeline fixes (#26795)
* Adjust length limits and allow naked conversation list inputs * Adjust length limits and allow naked conversation list inputs * Maybe use a slightly more reasonable limit than 1024 * Skip tests for old models that never supported this anyway * Cleanup input docstrings * More docstring cleanup + skip failing TF test * Make fixup
This commit is contained in:
@@ -247,13 +247,15 @@ class ConversationalPipeline(Pipeline):
|
||||
forward_params.update(generate_kwargs)
|
||||
return preprocess_params, forward_params, postprocess_params
|
||||
|
||||
def __call__(self, conversations: Union[Conversation, List[Conversation]], num_workers=0, **kwargs):
|
||||
def __call__(self, conversations: Union[List[Dict], Conversation, List[Conversation]], num_workers=0, **kwargs):
|
||||
r"""
|
||||
Generate responses for the conversation(s) given as inputs.
|
||||
|
||||
Args:
|
||||
conversations (a [`Conversation`] or a list of [`Conversation`]):
|
||||
Conversations to generate responses for.
|
||||
Conversation to generate responses for. Inputs can also be passed as a list of dictionaries with `role`
|
||||
and `content` keys - in this case, they will be converted to `Conversation` objects automatically.
|
||||
Multiple conversations in either format may be passed as a list.
|
||||
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to clean up the potential extra spaces in the text output.
|
||||
generate_kwargs:
|
||||
@@ -268,6 +270,10 @@ class ConversationalPipeline(Pipeline):
|
||||
# Otherwise the threads will require a Conversation copy.
|
||||
# This will definitely hinder performance on GPU, but has to be opted
|
||||
# in because of this BC change.
|
||||
if isinstance(conversations, list) and isinstance(conversations[0], dict):
|
||||
conversations = Conversation(conversations)
|
||||
elif isinstance(conversations, list) and isinstance(conversations[0], list):
|
||||
conversations = [Conversation(conv) for conv in conversations]
|
||||
outputs = super().__call__(conversations, num_workers=num_workers, **kwargs)
|
||||
if isinstance(outputs, list) and len(outputs) == 1:
|
||||
return outputs[0]
|
||||
@@ -283,19 +289,10 @@ class ConversationalPipeline(Pipeline):
|
||||
return {"input_ids": input_ids, "conversation": conversation}
|
||||
|
||||
def _forward(self, model_inputs, minimum_tokens=10, **generate_kwargs):
|
||||
max_length = generate_kwargs.get("max_length", self.model.config.max_length)
|
||||
|
||||
n = model_inputs["input_ids"].shape[1]
|
||||
if max_length - minimum_tokens < n:
|
||||
logger.warning(
|
||||
f"Conversation input is too long ({n}), trimming it to {max_length - minimum_tokens} tokens. Consider increasing `max_length` to avoid truncation."
|
||||
)
|
||||
trim = max_length - minimum_tokens
|
||||
model_inputs["input_ids"] = model_inputs["input_ids"][:, -trim:]
|
||||
if "attention_mask" in model_inputs:
|
||||
model_inputs["attention_mask"] = model_inputs["attention_mask"][:, -trim:]
|
||||
conversation = model_inputs.pop("conversation")
|
||||
generate_kwargs["max_length"] = max_length
|
||||
if "max_length" not in generate_kwargs and "max_new_tokens" not in generate_kwargs:
|
||||
generate_kwargs["max_new_tokens"] = 256
|
||||
output_ids = self.model.generate(**model_inputs, **generate_kwargs)
|
||||
if self.model.config.is_encoder_decoder:
|
||||
start_position = 1
|
||||
|
||||
@@ -507,6 +507,10 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
model.generate(input_ids, attention_mask=attention_mask)
|
||||
model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
|
||||
|
||||
@unittest.skip("Does not support conversations.")
|
||||
def test_pipeline_conversational(self):
|
||||
pass
|
||||
|
||||
|
||||
def assert_tensors_close(a, b, atol=1e-12, prefix=""):
|
||||
"""If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
|
||||
|
||||
@@ -343,6 +343,10 @@ class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, PipelineTester
|
||||
# check that the output for the restored model is the same
|
||||
self.assert_outputs_same(restored_model_outputs, outputs)
|
||||
|
||||
@unittest.skip("Does not support conversations.")
|
||||
def test_pipeline_conversational(self):
|
||||
pass
|
||||
|
||||
|
||||
def _long_tensor(tok_lst):
|
||||
return tf.constant(tok_lst, dtype=tf.int32)
|
||||
|
||||
@@ -891,6 +891,10 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
def test_disk_offload(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Does not support conversations.")
|
||||
def test_pipeline_conversational(self):
|
||||
pass
|
||||
|
||||
|
||||
class T5EncoderOnlyModelTester:
|
||||
def __init__(
|
||||
|
||||
@@ -314,6 +314,10 @@ class TFT5ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
def test_keras_save_load(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Does not support conversations.")
|
||||
def test_pipeline_conversational(self):
|
||||
pass
|
||||
|
||||
|
||||
class TFT5EncoderOnlyModelTester:
|
||||
def __init__(
|
||||
@@ -607,6 +611,10 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
|
||||
expected_output_string = ["Ich liebe es so sehr!", "die Transformatoren sind wirklich erstaunlich"]
|
||||
self.assertListEqual(expected_output_string, output_strings)
|
||||
|
||||
@unittest.skip("Does not support conversations.")
|
||||
def test_pipeline_conversational(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_tf
|
||||
@require_sentencepiece
|
||||
|
||||
Reference in New Issue
Block a user