support chat generator as input of TextGenerationPipeline (#35551)
* support chat generator as input of TextGenerationPipeline * missing import * fix tests * again * simpler * add test
This commit is contained in:
@@ -292,6 +292,50 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_small_chat_model_with_iterator_pt(self):
|
||||
from transformers.pipelines.pt_utils import PipelineIterator
|
||||
|
||||
text_generator = pipeline(
|
||||
task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt"
|
||||
)
|
||||
|
||||
# Using `do_sample=False` to force deterministic output
|
||||
chat1 = [
|
||||
{"role": "system", "content": "This is a system message."},
|
||||
{"role": "user", "content": "This is a test"},
|
||||
]
|
||||
chat2 = [
|
||||
{"role": "system", "content": "This is a system message."},
|
||||
{"role": "user", "content": "This is a second test"},
|
||||
]
|
||||
expected_chat1 = chat1 + [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": " factors factors factors factors factors factors factors factors factors factors",
|
||||
}
|
||||
]
|
||||
expected_chat2 = chat2 + [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": " stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
|
||||
}
|
||||
]
|
||||
|
||||
def data():
|
||||
yield from [chat1, chat2]
|
||||
|
||||
outputs = text_generator(data(), do_sample=False, max_new_tokens=10)
|
||||
assert isinstance(outputs, PipelineIterator)
|
||||
outputs = list(outputs)
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
[{"generated_text": expected_chat1}],
|
||||
[{"generated_text": expected_chat2}],
|
||||
],
|
||||
)
|
||||
|
||||
@require_tf
|
||||
def test_small_model_tf(self):
|
||||
text_generator = pipeline(task="text-generation", model="sshleifer/tiny-ctrl", framework="tf")
|
||||
|
||||
Reference in New Issue
Block a user