Add chat templating support for KeyDataset in text-generation pipeline (#30558)
* added chat templating support for keydataset in generation pipeline * fixed and improved test * fix formatting test failures * Fix tests * Fix tests
This commit is contained in:
@@ -177,6 +177,48 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_small_chat_model_with_dataset_pt(self):
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from transformers.pipelines.pt_utils import KeyDataset
|
||||
|
||||
class MyDataset(Dataset):
|
||||
data = [
|
||||
[
|
||||
{"role": "system", "content": "This is a system message."},
|
||||
{"role": "user", "content": "This is a test"},
|
||||
{"role": "assistant", "content": "This is a reply"},
|
||||
],
|
||||
]
|
||||
|
||||
def __len__(self):
|
||||
return 1
|
||||
|
||||
def __getitem__(self, i):
|
||||
return {"text": self.data[i]}
|
||||
|
||||
text_generator = pipeline(
|
||||
task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="pt"
|
||||
)
|
||||
|
||||
dataset = MyDataset()
|
||||
key_dataset = KeyDataset(dataset, "text")
|
||||
|
||||
for outputs in text_generator(key_dataset, do_sample=False, max_new_tokens=10):
|
||||
expected_chat = dataset.data[0] + [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": " factors factors factors factors factors factors factors factors factors factors",
|
||||
}
|
||||
]
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
{"generated_text": expected_chat},
|
||||
],
|
||||
)
|
||||
|
||||
@require_tf
|
||||
def test_small_model_tf(self):
|
||||
text_generator = pipeline(task="text-generation", model="sshleifer/tiny-ctrl", framework="tf")
|
||||
|
||||
Reference in New Issue
Block a user