From 14b04b4b9c483d94fadd2b5479ed9430bae8ac84 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 16 Oct 2023 17:27:45 +0100 Subject: [PATCH] Conversation pipeline fixes (#26795) * Adjust length limits and allow naked conversation list inputs * Adjust length limits and allow naked conversation list inputs * Maybe use a slightly more reasonable limit than 1024 * Skip tests for old models that never supported this anyway * Cleanup input docstrings * More docstring cleanup + skip failing TF test * Make fixup --- src/transformers/pipelines/conversational.py | 23 +++++++++----------- tests/models/bart/test_modeling_bart.py | 4 ++++ tests/models/bart/test_modeling_tf_bart.py | 4 ++++ tests/models/t5/test_modeling_t5.py | 4 ++++ tests/models/t5/test_modeling_tf_t5.py | 8 +++++++ 5 files changed, 30 insertions(+), 13 deletions(-) diff --git a/src/transformers/pipelines/conversational.py b/src/transformers/pipelines/conversational.py index ccfa36c4ec..2beaf8cc2e 100644 --- a/src/transformers/pipelines/conversational.py +++ b/src/transformers/pipelines/conversational.py @@ -247,13 +247,15 @@ class ConversationalPipeline(Pipeline): forward_params.update(generate_kwargs) return preprocess_params, forward_params, postprocess_params - def __call__(self, conversations: Union[Conversation, List[Conversation]], num_workers=0, **kwargs): + def __call__(self, conversations: Union[List[Dict], Conversation, List[Conversation]], num_workers=0, **kwargs): r""" Generate responses for the conversation(s) given as inputs. Args: conversations (a [`Conversation`] or a list of [`Conversation`]): - Conversations to generate responses for. + Conversation to generate responses for. Inputs can also be passed as a list of dictionaries with `role` + and `content` keys - in this case, they will be converted to `Conversation` objects automatically. + Multiple conversations in either format may be passed as a list. clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): Whether or not to clean up the potential extra spaces in the text output. generate_kwargs: @@ -268,6 +270,10 @@ class ConversationalPipeline(Pipeline): # Otherwise the threads will require a Conversation copy. # This will definitely hinder performance on GPU, but has to be opted # in because of this BC change. + if isinstance(conversations, list) and isinstance(conversations[0], dict): + conversations = Conversation(conversations) + elif isinstance(conversations, list) and isinstance(conversations[0], list): + conversations = [Conversation(conv) for conv in conversations] outputs = super().__call__(conversations, num_workers=num_workers, **kwargs) if isinstance(outputs, list) and len(outputs) == 1: return outputs[0] @@ -283,19 +289,10 @@ class ConversationalPipeline(Pipeline): return {"input_ids": input_ids, "conversation": conversation} def _forward(self, model_inputs, minimum_tokens=10, **generate_kwargs): - max_length = generate_kwargs.get("max_length", self.model.config.max_length) - n = model_inputs["input_ids"].shape[1] - if max_length - minimum_tokens < n: - logger.warning( - f"Conversation input is too long ({n}), trimming it to {max_length - minimum_tokens} tokens. Consider increasing `max_length` to avoid truncation." - ) - trim = max_length - minimum_tokens - model_inputs["input_ids"] = model_inputs["input_ids"][:, -trim:] - if "attention_mask" in model_inputs: - model_inputs["attention_mask"] = model_inputs["attention_mask"][:, -trim:] conversation = model_inputs.pop("conversation") - generate_kwargs["max_length"] = max_length + if "max_length" not in generate_kwargs and "max_new_tokens" not in generate_kwargs: + generate_kwargs["max_new_tokens"] = 256 output_ids = self.model.generate(**model_inputs, **generate_kwargs) if self.model.config.is_encoder_decoder: start_position = 1 diff --git a/tests/models/bart/test_modeling_bart.py b/tests/models/bart/test_modeling_bart.py index 01189e5628..d91ecf4cf5 100644 --- a/tests/models/bart/test_modeling_bart.py +++ b/tests/models/bart/test_modeling_bart.py @@ -507,6 +507,10 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin model.generate(input_ids, attention_mask=attention_mask) model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3) + @unittest.skip("Does not support conversations.") + def test_pipeline_conversational(self): + pass + def assert_tensors_close(a, b, atol=1e-12, prefix=""): """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error.""" diff --git a/tests/models/bart/test_modeling_tf_bart.py b/tests/models/bart/test_modeling_tf_bart.py index 05720f2978..60b35dcbec 100644 --- a/tests/models/bart/test_modeling_tf_bart.py +++ b/tests/models/bart/test_modeling_tf_bart.py @@ -343,6 +343,10 @@ class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, PipelineTester # check that the output for the restored model is the same self.assert_outputs_same(restored_model_outputs, outputs) + @unittest.skip("Does not support conversations.") + def test_pipeline_conversational(self): + pass + def _long_tensor(tok_lst): return tf.constant(tok_lst, dtype=tf.int32) diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py index c94bfc1f11..68b9f45e15 100644 --- a/tests/models/t5/test_modeling_t5.py +++ b/tests/models/t5/test_modeling_t5.py @@ -891,6 +891,10 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, def test_disk_offload(self): pass + @unittest.skip("Does not support conversations.") + def test_pipeline_conversational(self): + pass + class T5EncoderOnlyModelTester: def __init__( diff --git a/tests/models/t5/test_modeling_tf_t5.py b/tests/models/t5/test_modeling_tf_t5.py index ec7488e4c3..9976e20baf 100644 --- a/tests/models/t5/test_modeling_tf_t5.py +++ b/tests/models/t5/test_modeling_tf_t5.py @@ -314,6 +314,10 @@ class TFT5ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase): def test_keras_save_load(self): pass + @unittest.skip("Does not support conversations.") + def test_pipeline_conversational(self): + pass + class TFT5EncoderOnlyModelTester: def __init__( @@ -607,6 +611,10 @@ class TFT5GenerationIntegrationTests(unittest.TestCase): expected_output_string = ["Ich liebe es so sehr!", "die Transformatoren sind wirklich erstaunlich"] self.assertListEqual(expected_output_string, output_strings) + @unittest.skip("Does not support conversations.") + def test_pipeline_conversational(self): + pass + @require_tf @require_sentencepiece