From 2ecefc3959705b3b4f696736fa1c5b44102f31c7 Mon Sep 17 00:00:00 2001 From: DarshanDeshpande <39432636+DarshanDeshpande@users.noreply.github.com> Date: Tue, 30 Apr 2024 11:51:41 -0700 Subject: [PATCH] Add chat templating support for KeyDataset in text-generation pipeline (#30558) * added chat templating support for keydataset in generation pipeline * fixed and improved test * fix formatting test failures * Fix tests * Fix tests --- src/transformers/pipelines/text_generation.py | 8 +++- .../test_pipelines_text_generation.py | 42 +++++++++++++++++++ 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index 2f1ad71c78..ff5af53ff7 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -8,6 +8,7 @@ from .base import Pipeline, build_pipeline_init_args if is_torch_available(): from ..models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES + from .pt_utils import KeyDataset if is_tf_available(): import tensorflow as tf @@ -243,7 +244,9 @@ class TextGenerationPipeline(Pipeline): - **generated_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The token ids of the generated text. """ - if isinstance(text_inputs, (list, tuple)) and isinstance(text_inputs[0], (list, tuple, dict)): + if isinstance( + text_inputs, (list, tuple, KeyDataset) if is_torch_available() else (list, tuple) + ) and isinstance(text_inputs[0], (list, tuple, dict)): # We have one or more prompts in list-of-dicts format, so this is chat mode if isinstance(text_inputs[0], dict): return super().__call__(Chat(text_inputs), **kwargs) @@ -380,7 +383,8 @@ class TextGenerationPipeline(Pipeline): if isinstance(prompt_text, str): all_text = prompt_text + all_text elif isinstance(prompt_text, Chat): - all_text = prompt_text.messages + [{"role": "assistant", "content": all_text}] + # Explicit list parsing is necessary for parsing chat datasets + all_text = list(prompt_text.messages) + [{"role": "assistant", "content": all_text}] record = {"generated_text": all_text} records.append(record) diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index ada04c7dbe..542f393b20 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -177,6 +177,48 @@ class TextGenerationPipelineTests(unittest.TestCase): ], ) + @require_torch + def test_small_chat_model_with_dataset_pt(self): + from torch.utils.data import Dataset + + from transformers.pipelines.pt_utils import KeyDataset + + class MyDataset(Dataset): + data = [ + [ + {"role": "system", "content": "This is a system message."}, + {"role": "user", "content": "This is a test"}, + {"role": "assistant", "content": "This is a reply"}, + ], + ] + + def __len__(self): + return 1 + + def __getitem__(self, i): + return {"text": self.data[i]} + + text_generator = pipeline( + task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="pt" + ) + + dataset = MyDataset() + key_dataset = KeyDataset(dataset, "text") + + for outputs in text_generator(key_dataset, do_sample=False, max_new_tokens=10): + expected_chat = dataset.data[0] + [ + { + "role": "assistant", + "content": " factors factors factors factors factors factors factors factors factors factors", + } + ] + self.assertEqual( + outputs, + [ + {"generated_text": expected_chat}, + ], + ) + @require_tf def test_small_model_tf(self): text_generator = pipeline(task="text-generation", model="sshleifer/tiny-ctrl", framework="tf")