Add chat templating support for KeyDataset in text-generation pipeline (#30558)
* Added chat templating support for KeyDataset in generation pipeline
* Fixed and improved test
* Fix formatting test failures
* Fix tests
* Fix tests
This commit is contained in:
@@ -8,6 +8,7 @@ from .base import Pipeline, build_pipeline_init_args
|
|||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from ..models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
from ..models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
||||||
|
from .pt_utils import KeyDataset
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
@@ -243,7 +244,9 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
- **generated_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The token
|
- **generated_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The token
|
||||||
ids of the generated text.
|
ids of the generated text.
|
||||||
"""
|
"""
|
||||||
if isinstance(text_inputs, (list, tuple)) and isinstance(text_inputs[0], (list, tuple, dict)):
|
if isinstance(
|
||||||
|
text_inputs, (list, tuple, KeyDataset) if is_torch_available() else (list, tuple)
|
||||||
|
) and isinstance(text_inputs[0], (list, tuple, dict)):
|
||||||
# We have one or more prompts in list-of-dicts format, so this is chat mode
|
# We have one or more prompts in list-of-dicts format, so this is chat mode
|
||||||
if isinstance(text_inputs[0], dict):
|
if isinstance(text_inputs[0], dict):
|
||||||
return super().__call__(Chat(text_inputs), **kwargs)
|
return super().__call__(Chat(text_inputs), **kwargs)
|
||||||
@@ -380,7 +383,8 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
if isinstance(prompt_text, str):
|
if isinstance(prompt_text, str):
|
||||||
all_text = prompt_text + all_text
|
all_text = prompt_text + all_text
|
||||||
elif isinstance(prompt_text, Chat):
|
elif isinstance(prompt_text, Chat):
|
||||||
all_text = prompt_text.messages + [{"role": "assistant", "content": all_text}]
|
# Explicit list parsing is necessary for parsing chat datasets
|
||||||
|
all_text = list(prompt_text.messages) + [{"role": "assistant", "content": all_text}]
|
||||||
|
|
||||||
record = {"generated_text": all_text}
|
record = {"generated_text": all_text}
|
||||||
records.append(record)
|
records.append(record)
|
||||||
|
|||||||
@@ -177,6 +177,48 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_small_chat_model_with_dataset_pt(self):
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
from transformers.pipelines.pt_utils import KeyDataset
|
||||||
|
|
||||||
|
class MyDataset(Dataset):
|
||||||
|
data = [
|
||||||
|
[
|
||||||
|
{"role": "system", "content": "This is a system message."},
|
||||||
|
{"role": "user", "content": "This is a test"},
|
||||||
|
{"role": "assistant", "content": "This is a reply"},
|
||||||
|
],
|
||||||
|
]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return 1
|
||||||
|
|
||||||
|
def __getitem__(self, i):
|
||||||
|
return {"text": self.data[i]}
|
||||||
|
|
||||||
|
text_generator = pipeline(
|
||||||
|
task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="pt"
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset = MyDataset()
|
||||||
|
key_dataset = KeyDataset(dataset, "text")
|
||||||
|
|
||||||
|
for outputs in text_generator(key_dataset, do_sample=False, max_new_tokens=10):
|
||||||
|
expected_chat = dataset.data[0] + [
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": " factors factors factors factors factors factors factors factors factors factors",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
self.assertEqual(
|
||||||
|
outputs,
|
||||||
|
[
|
||||||
|
{"generated_text": expected_chat},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
def test_small_model_tf(self):
|
def test_small_model_tf(self):
|
||||||
text_generator = pipeline(task="text-generation", model="sshleifer/tiny-ctrl", framework="tf")
|
text_generator = pipeline(task="text-generation", model="sshleifer/tiny-ctrl", framework="tf")
|
||||||
|
|||||||
Reference in New Issue
Block a user