Add chat templating support for KeyDataset in text-generation pipeline (#30558)

* added chat templating support for keydataset in generation pipeline

* fixed and improved test

* fix formatting test failures

* Fix tests

* Fix tests
This commit is contained in:
DarshanDeshpande
2024-04-30 11:51:41 -07:00
committed by GitHub
parent 0cdb6b3f92
commit 2ecefc3959
2 changed files with 48 additions and 2 deletions

View File

@@ -177,6 +177,48 @@ class TextGenerationPipelineTests(unittest.TestCase):
],
)
@require_torch
def test_small_chat_model_with_dataset_pt(self):
from torch.utils.data import Dataset
from transformers.pipelines.pt_utils import KeyDataset
class MyDataset(Dataset):
data = [
[
{"role": "system", "content": "This is a system message."},
{"role": "user", "content": "This is a test"},
{"role": "assistant", "content": "This is a reply"},
],
]
def __len__(self):
return 1
def __getitem__(self, i):
return {"text": self.data[i]}
text_generator = pipeline(
task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="pt"
)
dataset = MyDataset()
key_dataset = KeyDataset(dataset, "text")
for outputs in text_generator(key_dataset, do_sample=False, max_new_tokens=10):
expected_chat = dataset.data[0] + [
{
"role": "assistant",
"content": " factors factors factors factors factors factors factors factors factors factors",
}
]
self.assertEqual(
outputs,
[
{"generated_text": expected_chat},
],
)
@require_tf
def test_small_model_tf(self):
text_generator = pipeline(task="text-generation", model="sshleifer/tiny-ctrl", framework="tf")