Add assistant prefill for chat templates and TextGenerationPipeline (#33198)
* Add assistant prefill to chat templates * Add assistant prefill to pipeline * Add assistant prefill to pipeline * Tweak another test that ended in assistant message * Update tests that ended in assistant messages * Update tests that ended in assistant messages * Replace assistant_prefill with continue_final_message * Allow passing continue_final_message to pipeline * Small fixup * Add continue_final_message as a pipeline kwarg * Update docstrings * Move repos to hf-internal-testing! * Update src/transformers/tokenization_utils_base.py Co-authored-by: Lysandre Debut <hi@lysand.re> * Add explanatory comment * make fixup * Update chat templating docs to explain continue_last_message --------- Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
@@ -877,8 +877,8 @@ class CustomPipelineTest(unittest.TestCase):
|
||||
# See https://github.com/huggingface/transformers/issues/31669
|
||||
text_generator = pipeline(
|
||||
"text-generation",
|
||||
model="Rocketknight1/fake-custom-model-test",
|
||||
tokenizer="Rocketknight1/fake-custom-model-test",
|
||||
model="hf-internal-testing/tiny-random-custom-architecture",
|
||||
tokenizer="hf-internal-testing/tiny-random-custom-architecture",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
@@ -888,8 +888,8 @@ class CustomPipelineTest(unittest.TestCase):
|
||||
def test_custom_code_with_string_feature_extractor(self):
|
||||
speech_recognizer = pipeline(
|
||||
"automatic-speech-recognition",
|
||||
model="Rocketknight1/fake-custom-wav2vec2",
|
||||
feature_extractor="Rocketknight1/fake-custom-wav2vec2",
|
||||
model="hf-internal-testing/fake-custom-wav2vec2",
|
||||
feature_extractor="hf-internal-testing/fake-custom-wav2vec2",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
@@ -899,8 +899,8 @@ class CustomPipelineTest(unittest.TestCase):
|
||||
def test_custom_code_with_string_preprocessor(self):
|
||||
mask_generator = pipeline(
|
||||
"mask-generation",
|
||||
model="Rocketknight1/fake-custom-sam",
|
||||
processor="Rocketknight1/fake-custom-sam",
|
||||
model="hf-internal-testing/fake-custom-sam",
|
||||
processor="hf-internal-testing/fake-custom-sam",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -148,18 +148,16 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
||||
@require_torch
|
||||
def test_small_chat_model_pt(self):
|
||||
text_generator = pipeline(
|
||||
task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="pt"
|
||||
task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt"
|
||||
)
|
||||
# Using `do_sample=False` to force deterministic output
|
||||
chat1 = [
|
||||
{"role": "system", "content": "This is a system message."},
|
||||
{"role": "user", "content": "This is a test"},
|
||||
{"role": "assistant", "content": "This is a reply"},
|
||||
]
|
||||
chat2 = [
|
||||
{"role": "system", "content": "This is a system message."},
|
||||
{"role": "user", "content": "This is a second test"},
|
||||
{"role": "assistant", "content": "This is a reply"},
|
||||
]
|
||||
outputs = text_generator(chat1, do_sample=False, max_new_tokens=10)
|
||||
expected_chat1 = chat1 + [
|
||||
@@ -179,7 +177,7 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
||||
expected_chat2 = chat2 + [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": " factors factors factors factors factors factors factors factors factors factors",
|
||||
"content": " stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
|
||||
}
|
||||
]
|
||||
|
||||
@@ -191,6 +189,68 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_small_chat_model_continue_final_message(self):
|
||||
# Here we check that passing a chat that ends in an assistant message is handled correctly
|
||||
# by continuing the final message rather than starting a new one
|
||||
text_generator = pipeline(
|
||||
task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt"
|
||||
)
|
||||
# Using `do_sample=False` to force deterministic output
|
||||
chat1 = [
|
||||
{"role": "system", "content": "This is a system message."},
|
||||
{"role": "user", "content": "This is a test"},
|
||||
{"role": "assistant", "content": "This is"},
|
||||
]
|
||||
outputs = text_generator(chat1, do_sample=False, max_new_tokens=10)
|
||||
|
||||
# Assert that we continued the last message and there isn't a sneaky <|im_end|>
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
{
|
||||
"generated_text": [
|
||||
{"role": "system", "content": "This is a system message."},
|
||||
{"role": "user", "content": "This is a test"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "This is stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
|
||||
},
|
||||
]
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_small_chat_model_continue_final_message_override(self):
|
||||
# Here we check that passing a chat that ends in an assistant message is handled correctly
|
||||
# by continuing the final message rather than starting a new one
|
||||
text_generator = pipeline(
|
||||
task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt"
|
||||
)
|
||||
# Using `do_sample=False` to force deterministic output
|
||||
chat1 = [
|
||||
{"role": "system", "content": "This is a system message."},
|
||||
{"role": "user", "content": "This is a test"},
|
||||
]
|
||||
outputs = text_generator(chat1, do_sample=False, max_new_tokens=10, continue_final_message=True)
|
||||
|
||||
# Assert that we continued the last message and there isn't a sneaky <|im_end|>
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
{
|
||||
"generated_text": [
|
||||
{"role": "system", "content": "This is a system message."},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "This is a test stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
|
||||
},
|
||||
]
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_small_chat_model_with_dataset_pt(self):
|
||||
from torch.utils.data import Dataset
|
||||
@@ -202,7 +262,6 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
||||
[
|
||||
{"role": "system", "content": "This is a system message."},
|
||||
{"role": "user", "content": "This is a test"},
|
||||
{"role": "assistant", "content": "This is a reply"},
|
||||
],
|
||||
]
|
||||
|
||||
@@ -213,7 +272,7 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
||||
return {"text": self.data[i]}
|
||||
|
||||
text_generator = pipeline(
|
||||
task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="pt"
|
||||
task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt"
|
||||
)
|
||||
|
||||
dataset = MyDataset()
|
||||
@@ -277,18 +336,16 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
||||
@require_tf
|
||||
def test_small_chat_model_tf(self):
|
||||
text_generator = pipeline(
|
||||
task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="tf"
|
||||
task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="tf"
|
||||
)
|
||||
# Using `do_sample=False` to force deterministic output
|
||||
chat1 = [
|
||||
{"role": "system", "content": "This is a system message."},
|
||||
{"role": "user", "content": "This is a test"},
|
||||
{"role": "assistant", "content": "This is a reply"},
|
||||
]
|
||||
chat2 = [
|
||||
{"role": "system", "content": "This is a system message."},
|
||||
{"role": "user", "content": "This is a second test"},
|
||||
{"role": "assistant", "content": "This is a reply"},
|
||||
]
|
||||
outputs = text_generator(chat1, do_sample=False, max_new_tokens=10)
|
||||
expected_chat1 = chat1 + [
|
||||
@@ -308,7 +365,7 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
||||
expected_chat2 = chat2 + [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": " factors factors factors factors factors factors factors factors factors factors",
|
||||
"content": " stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user