Add assistant prefill for chat templates and TextGenerationPipeline (#33198)

* Add assistant prefill to chat templates

* Add assistant prefill to pipeline

* Add assistant prefill to pipeline

* Tweak another test that ended in assistant message

* Update tests that ended in assistant messages

* Update tests that ended in assistant messages

* Replace assistant_prefill with continue_final_message

* Allow passing continue_final_message to pipeline

* Small fixup

* Add continue_final_message as a pipeline kwarg

* Update docstrings

* Move repos to hf-internal-testing!

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: Lysandre Debut <hi@lysand.re>

* Add explanatory comment

* make fixup

* Update chat templating docs to explain continue_final_message

---------

Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
Matt
2024-09-02 13:23:47 +01:00
committed by GitHub
parent 2d37085817
commit 52a0213755
6 changed files with 199 additions and 23 deletions

View File

@@ -877,8 +877,8 @@ class CustomPipelineTest(unittest.TestCase):
# See https://github.com/huggingface/transformers/issues/31669
text_generator = pipeline(
"text-generation",
model="Rocketknight1/fake-custom-model-test",
tokenizer="Rocketknight1/fake-custom-model-test",
model="hf-internal-testing/tiny-random-custom-architecture",
tokenizer="hf-internal-testing/tiny-random-custom-architecture",
trust_remote_code=True,
)
@@ -888,8 +888,8 @@ class CustomPipelineTest(unittest.TestCase):
def test_custom_code_with_string_feature_extractor(self):
speech_recognizer = pipeline(
"automatic-speech-recognition",
model="Rocketknight1/fake-custom-wav2vec2",
feature_extractor="Rocketknight1/fake-custom-wav2vec2",
model="hf-internal-testing/fake-custom-wav2vec2",
feature_extractor="hf-internal-testing/fake-custom-wav2vec2",
trust_remote_code=True,
)
@@ -899,8 +899,8 @@ class CustomPipelineTest(unittest.TestCase):
def test_custom_code_with_string_preprocessor(self):
mask_generator = pipeline(
"mask-generation",
model="Rocketknight1/fake-custom-sam",
processor="Rocketknight1/fake-custom-sam",
model="hf-internal-testing/fake-custom-sam",
processor="hf-internal-testing/fake-custom-sam",
trust_remote_code=True,
)

View File

@@ -148,18 +148,16 @@ class TextGenerationPipelineTests(unittest.TestCase):
@require_torch
def test_small_chat_model_pt(self):
text_generator = pipeline(
task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="pt"
task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt"
)
# Using `do_sample=False` to force deterministic output
chat1 = [
{"role": "system", "content": "This is a system message."},
{"role": "user", "content": "This is a test"},
{"role": "assistant", "content": "This is a reply"},
]
chat2 = [
{"role": "system", "content": "This is a system message."},
{"role": "user", "content": "This is a second test"},
{"role": "assistant", "content": "This is a reply"},
]
outputs = text_generator(chat1, do_sample=False, max_new_tokens=10)
expected_chat1 = chat1 + [
@@ -179,7 +177,7 @@ class TextGenerationPipelineTests(unittest.TestCase):
expected_chat2 = chat2 + [
{
"role": "assistant",
"content": " factors factors factors factors factors factors factors factors factors factors",
"content": " stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
}
]
@@ -191,6 +189,68 @@ class TextGenerationPipelineTests(unittest.TestCase):
],
)
@require_torch
def test_small_chat_model_continue_final_message(self):
# Here we check that passing a chat that ends in an assistant message is handled correctly
# by continuing the final message rather than starting a new one
text_generator = pipeline(
task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt"
)
# Using `do_sample=False` to force deterministic output
chat1 = [
{"role": "system", "content": "This is a system message."},
{"role": "user", "content": "This is a test"},
{"role": "assistant", "content": "This is"},
]
outputs = text_generator(chat1, do_sample=False, max_new_tokens=10)
# Assert that we continued the last message and there isn't a sneaky <|im_end|>
self.assertEqual(
outputs,
[
{
"generated_text": [
{"role": "system", "content": "This is a system message."},
{"role": "user", "content": "This is a test"},
{
"role": "assistant",
"content": "This is stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
},
]
}
],
)
@require_torch
def test_small_chat_model_continue_final_message_override(self):
# Here we check that explicitly passing continue_final_message=True is respected even when the chat
# ends in a user message: the final (user) message is continued rather than a new message being started
text_generator = pipeline(
task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt"
)
# Using `do_sample=False` to force deterministic output
chat1 = [
{"role": "system", "content": "This is a system message."},
{"role": "user", "content": "This is a test"},
]
outputs = text_generator(chat1, do_sample=False, max_new_tokens=10, continue_final_message=True)
# Assert that we continued the last message and there isn't a sneaky <|im_end|>
self.assertEqual(
outputs,
[
{
"generated_text": [
{"role": "system", "content": "This is a system message."},
{
"role": "user",
"content": "This is a test stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
},
]
}
],
)
@require_torch
def test_small_chat_model_with_dataset_pt(self):
from torch.utils.data import Dataset
@@ -202,7 +262,6 @@ class TextGenerationPipelineTests(unittest.TestCase):
[
{"role": "system", "content": "This is a system message."},
{"role": "user", "content": "This is a test"},
{"role": "assistant", "content": "This is a reply"},
],
]
@@ -213,7 +272,7 @@ class TextGenerationPipelineTests(unittest.TestCase):
return {"text": self.data[i]}
text_generator = pipeline(
task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="pt"
task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt"
)
dataset = MyDataset()
@@ -277,18 +336,16 @@ class TextGenerationPipelineTests(unittest.TestCase):
@require_tf
def test_small_chat_model_tf(self):
text_generator = pipeline(
task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="tf"
task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="tf"
)
# Using `do_sample=False` to force deterministic output
chat1 = [
{"role": "system", "content": "This is a system message."},
{"role": "user", "content": "This is a test"},
{"role": "assistant", "content": "This is a reply"},
]
chat2 = [
{"role": "system", "content": "This is a system message."},
{"role": "user", "content": "This is a second test"},
{"role": "assistant", "content": "This is a reply"},
]
outputs = text_generator(chat1, do_sample=False, max_new_tokens=10)
expected_chat1 = chat1 + [
@@ -308,7 +365,7 @@ class TextGenerationPipelineTests(unittest.TestCase):
expected_chat2 = chat2 + [
{
"role": "assistant",
"content": " factors factors factors factors factors factors factors factors factors factors",
"content": " stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
}
]