Add assistant prefill for chat templates and TextGenerationPipeline (#33198)
* Add assistant prefill to chat templates
* Add assistant prefill to pipeline
* Add assistant prefill to pipeline
* Tweak another test that ended in assistant message
* Update tests that ended in assistant messages
* Update tests that ended in assistant messages
* Replace assistant_prefill with continue_final_message
* Allow passing continue_final_message to pipeline
* Small fixup
* Add continue_final_message as a pipeline kwarg
* Update docstrings
* Move repos to hf-internal-testing!
* Update src/transformers/tokenization_utils_base.py
  Co-authored-by: Lysandre Debut <hi@lysand.re>
* Add explanatory comment
* make fixup
* Update chat templating docs to explain continue_last_message

---------

Co-authored-by: Lysandre Debut <hi@lysand.re>
@@ -197,6 +197,43 @@ Not all models require generation prompts. Some models, like BlenderBot and LLaM
|
|||||||
special tokens before bot responses. In these cases, the `add_generation_prompt` argument will have no effect. The exact
|
special tokens before bot responses. In these cases, the `add_generation_prompt` argument will have no effect. The exact
|
||||||
effect that `add_generation_prompt` has will depend on the template being used.
|
effect that `add_generation_prompt` has will depend on the template being used.
|
||||||
|
|
||||||
|
## What does "continue_final_message" do?
|
||||||
|
|
||||||
|
When passing a list of messages to `apply_chat_template` or `TextGenerationPipeline`, you can choose
|
||||||
|
to format the chat so the model will continue the final message in the chat instead of starting a new one. This is done
|
||||||
|
by removing any end-of-sequence tokens that indicate the end of the final message, so that the model will simply
|
||||||
|
extend the final message when it begins to generate text. This is useful for "prefilling" the model's response.
|
||||||
|
|
||||||
|
Here's an example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
chat = [
|
||||||
|
{"role": "user", "content": "Can you format the answer in JSON?"},
|
||||||
|
{"role": "assistant", "content": '{"name": "'},
|
||||||
|
]
|
||||||
|
|
||||||
|
formatted_chat = tokenizer.apply_chat_template(chat, tokenize=True, return_dict=True, continue_final_message=True)
|
||||||
|
model.generate(**formatted_chat)
|
||||||
|
```
|
||||||
|
|
||||||
|
The model will generate text that continues the JSON string, rather than starting a new message. This approach
|
||||||
|
can be very useful for improving the accuracy of the model's instruction-following when you know how you want
|
||||||
|
it to start its replies.
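
To make the mechanism concrete, here is a minimal sketch of the trimming step in plain Python, with no `transformers` dependency. It assumes a ChatML-style rendering like the one used in this change's tokenizer test, and `trim_to_final_message` is a hypothetical helper name (not a library function) that mirrors the `rindex`-based slice this change adds to `apply_chat_template`.

```python
def trim_to_final_message(rendered_chat: str, final_message: str) -> str:
    """Cut everything after the last occurrence of the final message's content.

    Hypothetical helper for illustration only; it mirrors the slice used in this change.
    """
    # rindex finds the *last* occurrence, so earlier messages with the same text are left alone.
    end = rendered_chat.rindex(final_message) + len(final_message)
    # rstrip removes any trailing whitespace at the end of the kept text.
    return rendered_chat[:end].rstrip()


# Assumed ChatML-style rendering of the two-message chat from the example above.
rendered = (
    "<|im_start|>user\nCan you format the answer in JSON?<|im_end|>\n"
    '<|im_start|>assistant\n{"name": "<|im_end|>\n'
)
prefill = trim_to_final_message(rendered, '{"name": "')
print(prefill)
```

One consequence of this approach is that the `rstrip()` also removes trailing whitespace from the prefill itself, so a prefill such as `"This is "` would likely reach the model without its final space.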
|
||||||
|
|
||||||
|
Because `add_generation_prompt` adds the tokens that start a new message, and `continue_final_message` removes any
|
||||||
|
end-of-message tokens from the final message, it does not make sense to use them together. As a result, you'll
|
||||||
|
get an error if you try!
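
As a rough illustration of that rule, the sketch below reproduces the mutual-exclusion check in isolation. `check_prompt_flags` is a hypothetical stand-in, not the library's function; in this change the equivalent check lives inside `apply_chat_template` and raises a `ValueError` with a longer message.

```python
def check_prompt_flags(add_generation_prompt: bool, continue_final_message: bool) -> None:
    """Hypothetical stand-in for the validation this change adds to apply_chat_template."""
    if continue_final_message and add_generation_prompt:
        raise ValueError("continue_final_message and add_generation_prompt are not compatible.")


# Either flag on its own is accepted.
check_prompt_flags(add_generation_prompt=True, continue_final_message=False)
check_prompt_flags(add_generation_prompt=False, continue_final_message=True)

# Setting both is rejected.
try:
    check_prompt_flags(add_generation_prompt=True, continue_final_message=True)
except ValueError as err:
    error_message = str(err)
    print(f"Rejected: {error_message}")
```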
|
||||||
|
|
||||||
|
<Tip>
|
||||||
|
|
||||||
|
The default behaviour of `TextGenerationPipeline` is to set `add_generation_prompt=True` so that it starts a new
|
||||||
|
message. However, if the final message in the input chat has the "assistant" role, it will assume that this message is
|
||||||
|
a prefill and switch to `continue_final_message=True` instead, because most models do not support multiple
|
||||||
|
consecutive assistant messages. You can override this behaviour by explicitly passing the `continue_final_message`
|
||||||
|
argument when calling the pipeline.
|
||||||
|
|
||||||
|
</Tip>
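
The default described in the tip can be sketched in plain Python. This is an illustrative reimplementation under stated assumptions, not the pipeline's own code: `resolve_continue_final_message` and `merge_generation` are hypothetical names, and messages are assumed to be plain dicts with `role` and `content` keys.

```python
from typing import Dict, List, Optional

Message = Dict[str, str]


def resolve_continue_final_message(
    messages: List[Message], continue_final_message: Optional[bool] = None
) -> bool:
    """Hypothetical helper: an explicit True/False wins, None means infer from the final role."""
    if continue_final_message is None:
        return messages[-1]["role"] == "assistant"
    return continue_final_message


def merge_generation(
    messages: List[Message], generated: str, continue_final_message: Optional[bool] = None
) -> List[Message]:
    """Hypothetical helper: build the output chat the way the postprocess step in this change does."""
    if resolve_continue_final_message(messages, continue_final_message):
        # Prefill: append the generated text to the end of the final message.
        last = messages[-1]
        return list(messages[:-1]) + [{"role": last["role"], "content": last["content"] + generated}]
    # No prefill: the generated text becomes a new assistant message.
    return list(messages) + [{"role": "assistant", "content": generated}]


prefilled = [
    {"role": "user", "content": "This is a test"},
    {"role": "assistant", "content": "This is"},
]
fresh = [{"role": "user", "content": "Hi"}]

continued = merge_generation(prefilled, " a reply")  # inferred: continue the assistant message
new_turn = merge_generation(fresh, "Hello")  # inferred: start a new assistant message
forced = merge_generation(fresh, " there", continue_final_message=True)  # explicit override
```

Note that the explicit override applies whatever the final role is, so forcing `continue_final_message=True` on a chat that ends in a user message extends that user message.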
|
||||||
|
|
||||||
## Can I use chat templates in training?
|
## Can I use chat templates in training?
|
||||||
|
|
||||||
Yes! This is a good way to ensure that the chat template matches the tokens the model sees during training.
|
Yes! This is a good way to ensure that the chat template matches the tokens the model sees during training.
|
||||||
|
|||||||
@@ -131,6 +131,7 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
stop_sequence=None,
|
stop_sequence=None,
|
||||||
truncation=None,
|
truncation=None,
|
||||||
max_length=None,
|
max_length=None,
|
||||||
|
continue_final_message=None,
|
||||||
**generate_kwargs,
|
**generate_kwargs,
|
||||||
):
|
):
|
||||||
preprocess_params = {}
|
preprocess_params = {}
|
||||||
@@ -165,6 +166,9 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
)
|
)
|
||||||
preprocess_params["handle_long_generation"] = handle_long_generation
|
preprocess_params["handle_long_generation"] = handle_long_generation
|
||||||
|
|
||||||
|
if continue_final_message is not None:
|
||||||
|
preprocess_params["continue_final_message"] = continue_final_message
|
||||||
|
|
||||||
preprocess_params.update(generate_kwargs)
|
preprocess_params.update(generate_kwargs)
|
||||||
forward_params = generate_kwargs
|
forward_params = generate_kwargs
|
||||||
|
|
||||||
@@ -183,6 +187,8 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
postprocess_params["return_type"] = return_type
|
postprocess_params["return_type"] = return_type
|
||||||
if clean_up_tokenization_spaces is not None:
|
if clean_up_tokenization_spaces is not None:
|
||||||
postprocess_params["clean_up_tokenization_spaces"] = clean_up_tokenization_spaces
|
postprocess_params["clean_up_tokenization_spaces"] = clean_up_tokenization_spaces
|
||||||
|
if continue_final_message is not None:
|
||||||
|
postprocess_params["continue_final_message"] = continue_final_message
|
||||||
|
|
||||||
if stop_sequence is not None:
|
if stop_sequence is not None:
|
||||||
stop_sequence_ids = self.tokenizer.encode(stop_sequence, add_special_tokens=False)
|
stop_sequence_ids = self.tokenizer.encode(stop_sequence, add_special_tokens=False)
|
||||||
@@ -226,6 +232,10 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
*return_text* is set to True.
|
*return_text* is set to True.
|
||||||
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
|
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
|
||||||
Whether or not to clean up the potential extra spaces in the text output.
|
Whether or not to clean up the potential extra spaces in the text output.
|
||||||
|
continue_final_message (`bool`, *optional*): This indicates that you want the model to continue the
|
||||||
|
last message in the input chat rather than starting a new one, allowing you to "prefill" its response.
|
||||||
|
By default this is `True` when the final message in the input chat has the `assistant` role and
|
||||||
|
`False` otherwise, but you can manually override that behaviour by setting this flag.
|
||||||
prefix (`str`, *optional*):
|
prefix (`str`, *optional*):
|
||||||
Prefix added to prompt.
|
Prefix added to prompt.
|
||||||
handle_long_generation (`str`, *optional*):
|
handle_long_generation (`str`, *optional*):
|
||||||
@@ -270,6 +280,7 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
truncation=None,
|
truncation=None,
|
||||||
padding=None,
|
padding=None,
|
||||||
max_length=None,
|
max_length=None,
|
||||||
|
continue_final_message=None,
|
||||||
**generate_kwargs,
|
**generate_kwargs,
|
||||||
):
|
):
|
||||||
# Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
|
# Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
|
||||||
@@ -283,9 +294,14 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
|
|
||||||
if isinstance(prompt_text, Chat):
|
if isinstance(prompt_text, Chat):
|
||||||
tokenizer_kwargs.pop("add_special_tokens", None) # ignore add_special_tokens on chats
|
tokenizer_kwargs.pop("add_special_tokens", None) # ignore add_special_tokens on chats
|
||||||
|
# If the user passes a chat that ends in an assistant message, we treat it as a prefill by default
|
||||||
|
# because very few models support multiple separate, consecutive assistant messages
|
||||||
|
if continue_final_message is None:
|
||||||
|
continue_final_message = prompt_text.messages[-1]["role"] == "assistant"
|
||||||
inputs = self.tokenizer.apply_chat_template(
|
inputs = self.tokenizer.apply_chat_template(
|
||||||
prompt_text.messages,
|
prompt_text.messages,
|
||||||
add_generation_prompt=True,
|
add_generation_prompt=not continue_final_message,
|
||||||
|
continue_final_message=continue_final_message,
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
return_tensors=self.framework,
|
return_tensors=self.framework,
|
||||||
**tokenizer_kwargs,
|
**tokenizer_kwargs,
|
||||||
@@ -356,7 +372,13 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
|
generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
|
||||||
return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}
|
return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}
|
||||||
|
|
||||||
def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True):
|
def postprocess(
|
||||||
|
self,
|
||||||
|
model_outputs,
|
||||||
|
return_type=ReturnType.FULL_TEXT,
|
||||||
|
clean_up_tokenization_spaces=True,
|
||||||
|
continue_final_message=None,
|
||||||
|
):
|
||||||
generated_sequence = model_outputs["generated_sequence"][0]
|
generated_sequence = model_outputs["generated_sequence"][0]
|
||||||
input_ids = model_outputs["input_ids"]
|
input_ids = model_outputs["input_ids"]
|
||||||
prompt_text = model_outputs["prompt_text"]
|
prompt_text = model_outputs["prompt_text"]
|
||||||
@@ -390,9 +412,21 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
if isinstance(prompt_text, str):
|
if isinstance(prompt_text, str):
|
||||||
all_text = prompt_text + all_text
|
all_text = prompt_text + all_text
|
||||||
elif isinstance(prompt_text, Chat):
|
elif isinstance(prompt_text, Chat):
|
||||||
# Explicit list parsing is necessary for parsing chat datasets
|
if continue_final_message is None:
|
||||||
|
# If the user passes a chat ending in an assistant message, we treat it as a prefill by
|
||||||
|
# default because very few models support multiple separate, consecutive assistant messages
|
||||||
|
continue_final_message = prompt_text.messages[-1]["role"] == "assistant"
|
||||||
|
if continue_final_message:
|
||||||
|
# With assistant prefill, concat onto the end of the last message
|
||||||
|
all_text = list(prompt_text.messages)[:-1] + [
|
||||||
|
{
|
||||||
|
"role": prompt_text.messages[-1]["role"],
|
||||||
|
"content": prompt_text.messages[-1]["content"] + all_text,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
# When we're not starting from a prefill, the output is a new assistant message
|
||||||
all_text = list(prompt_text.messages) + [{"role": "assistant", "content": all_text}]
|
all_text = list(prompt_text.messages) + [{"role": "assistant", "content": all_text}]
|
||||||
|
|
||||||
record = {"generated_text": all_text}
|
record = {"generated_text": all_text}
|
||||||
records.append(record)
|
records.append(record)
|
||||||
|
|
||||||
|
|||||||
@@ -1704,6 +1704,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
documents: Optional[List[Dict[str, str]]] = None,
|
documents: Optional[List[Dict[str, str]]] = None,
|
||||||
chat_template: Optional[str] = None,
|
chat_template: Optional[str] = None,
|
||||||
add_generation_prompt: bool = False,
|
add_generation_prompt: bool = False,
|
||||||
|
continue_final_message: bool = False,
|
||||||
tokenize: bool = True,
|
tokenize: bool = True,
|
||||||
padding: bool = False,
|
padding: bool = False,
|
||||||
truncation: bool = False,
|
truncation: bool = False,
|
||||||
@@ -1737,10 +1738,16 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
chat_template (`str`, *optional*):
|
chat_template (`str`, *optional*):
|
||||||
A Jinja template to use for this conversion. It is usually not necessary to pass anything to this
|
A Jinja template to use for this conversion. It is usually not necessary to pass anything to this
|
||||||
argument, as the model's template will be used by default.
|
argument, as the model's template will be used by default.
|
||||||
add_generation_prompt (bool, *optional*): Whether to end the prompt with the token(s) that indicate
|
add_generation_prompt (bool, *optional*):
|
||||||
the start of an assistant message. This is useful when you want to generate a response from the model.
|
If this is set, a prompt with the token(s) that indicate
|
||||||
|
the start of an assistant message will be appended to the formatted output. This is useful when you want to generate a response from the model.
|
||||||
Note that this argument will be passed to the chat template, and so it must be supported in the
|
Note that this argument will be passed to the chat template, and so it must be supported in the
|
||||||
template for this argument to have any effect.
|
template for this argument to have any effect.
|
||||||
|
continue_final_message (bool, *optional*):
|
||||||
|
If this is set, the chat will be formatted so that the final
|
||||||
|
message in the chat is open-ended, without any EOS tokens. The model will continue this message
|
||||||
|
rather than starting a new one. This allows you to "prefill" part of
|
||||||
|
the model's response for it. Cannot be used at the same time as `add_generation_prompt`.
|
||||||
tokenize (`bool`, defaults to `True`):
|
tokenize (`bool`, defaults to `True`):
|
||||||
Whether to tokenize the output. If `False`, the output will be a string.
|
Whether to tokenize the output. If `False`, the output will be a string.
|
||||||
padding (`bool`, defaults to `False`):
|
padding (`bool`, defaults to `False`):
|
||||||
@@ -1803,6 +1810,14 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
conversations = [conversation]
|
conversations = [conversation]
|
||||||
is_batched = False
|
is_batched = False
|
||||||
|
|
||||||
|
if continue_final_message:
|
||||||
|
if add_generation_prompt:
|
||||||
|
raise ValueError(
|
||||||
|
"continue_final_message and add_generation_prompt are not compatible. Use continue_final_message when you want the model to continue the final message, and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead."
|
||||||
|
)
|
||||||
|
if return_assistant_tokens_mask:
|
||||||
|
raise ValueError("continue_final_message is not compatible with return_assistant_tokens_mask.")
|
||||||
|
|
||||||
# We accept either JSON schemas or functions for tools. If we get functions, we convert them to schemas
|
# We accept either JSON schemas or functions for tools. If we get functions, we convert them to schemas
|
||||||
if tools is not None:
|
if tools is not None:
|
||||||
tool_schemas = []
|
tool_schemas = []
|
||||||
@@ -1849,6 +1864,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
add_generation_prompt=add_generation_prompt,
|
add_generation_prompt=add_generation_prompt,
|
||||||
**template_kwargs,
|
**template_kwargs,
|
||||||
)
|
)
|
||||||
|
if continue_final_message:
|
||||||
|
final_message = chat[-1]["content"]
|
||||||
|
rendered_chat = rendered_chat[: rendered_chat.rindex(final_message) + len(final_message)].rstrip()
|
||||||
rendered.append(rendered_chat)
|
rendered.append(rendered_chat)
|
||||||
|
|
||||||
if not is_batched:
|
if not is_batched:
|
||||||
|
|||||||
@@ -877,8 +877,8 @@ class CustomPipelineTest(unittest.TestCase):
|
|||||||
# See https://github.com/huggingface/transformers/issues/31669
|
# See https://github.com/huggingface/transformers/issues/31669
|
||||||
text_generator = pipeline(
|
text_generator = pipeline(
|
||||||
"text-generation",
|
"text-generation",
|
||||||
model="Rocketknight1/fake-custom-model-test",
|
model="hf-internal-testing/tiny-random-custom-architecture",
|
||||||
tokenizer="Rocketknight1/fake-custom-model-test",
|
tokenizer="hf-internal-testing/tiny-random-custom-architecture",
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -888,8 +888,8 @@ class CustomPipelineTest(unittest.TestCase):
|
|||||||
def test_custom_code_with_string_feature_extractor(self):
|
def test_custom_code_with_string_feature_extractor(self):
|
||||||
speech_recognizer = pipeline(
|
speech_recognizer = pipeline(
|
||||||
"automatic-speech-recognition",
|
"automatic-speech-recognition",
|
||||||
model="Rocketknight1/fake-custom-wav2vec2",
|
model="hf-internal-testing/fake-custom-wav2vec2",
|
||||||
feature_extractor="Rocketknight1/fake-custom-wav2vec2",
|
feature_extractor="hf-internal-testing/fake-custom-wav2vec2",
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -899,8 +899,8 @@ class CustomPipelineTest(unittest.TestCase):
|
|||||||
def test_custom_code_with_string_preprocessor(self):
|
def test_custom_code_with_string_preprocessor(self):
|
||||||
mask_generator = pipeline(
|
mask_generator = pipeline(
|
||||||
"mask-generation",
|
"mask-generation",
|
||||||
model="Rocketknight1/fake-custom-sam",
|
model="hf-internal-testing/fake-custom-sam",
|
||||||
processor="Rocketknight1/fake-custom-sam",
|
processor="hf-internal-testing/fake-custom-sam",
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -148,18 +148,16 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
|||||||
@require_torch
|
@require_torch
|
||||||
def test_small_chat_model_pt(self):
|
def test_small_chat_model_pt(self):
|
||||||
text_generator = pipeline(
|
text_generator = pipeline(
|
||||||
task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="pt"
|
task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt"
|
||||||
)
|
)
|
||||||
# Using `do_sample=False` to force deterministic output
|
# Using `do_sample=False` to force deterministic output
|
||||||
chat1 = [
|
chat1 = [
|
||||||
{"role": "system", "content": "This is a system message."},
|
{"role": "system", "content": "This is a system message."},
|
||||||
{"role": "user", "content": "This is a test"},
|
{"role": "user", "content": "This is a test"},
|
||||||
{"role": "assistant", "content": "This is a reply"},
|
|
||||||
]
|
]
|
||||||
chat2 = [
|
chat2 = [
|
||||||
{"role": "system", "content": "This is a system message."},
|
{"role": "system", "content": "This is a system message."},
|
||||||
{"role": "user", "content": "This is a second test"},
|
{"role": "user", "content": "This is a second test"},
|
||||||
{"role": "assistant", "content": "This is a reply"},
|
|
||||||
]
|
]
|
||||||
outputs = text_generator(chat1, do_sample=False, max_new_tokens=10)
|
outputs = text_generator(chat1, do_sample=False, max_new_tokens=10)
|
||||||
expected_chat1 = chat1 + [
|
expected_chat1 = chat1 + [
|
||||||
@@ -179,7 +177,7 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
|||||||
expected_chat2 = chat2 + [
|
expected_chat2 = chat2 + [
|
||||||
{
|
{
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": " factors factors factors factors factors factors factors factors factors factors",
|
"content": " stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -191,6 +189,68 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_small_chat_model_continue_final_message(self):
|
||||||
|
# Here we check that passing a chat that ends in an assistant message is handled correctly
|
||||||
|
# by continuing the final message rather than starting a new one
|
||||||
|
text_generator = pipeline(
|
||||||
|
task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt"
|
||||||
|
)
|
||||||
|
# Using `do_sample=False` to force deterministic output
|
||||||
|
chat1 = [
|
||||||
|
{"role": "system", "content": "This is a system message."},
|
||||||
|
{"role": "user", "content": "This is a test"},
|
||||||
|
{"role": "assistant", "content": "This is"},
|
||||||
|
]
|
||||||
|
outputs = text_generator(chat1, do_sample=False, max_new_tokens=10)
|
||||||
|
|
||||||
|
# Assert that we continued the last message and there isn't a sneaky <|im_end|>
|
||||||
|
self.assertEqual(
|
||||||
|
outputs,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"generated_text": [
|
||||||
|
{"role": "system", "content": "This is a system message."},
|
||||||
|
{"role": "user", "content": "This is a test"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "This is stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_small_chat_model_continue_final_message_override(self):
|
||||||
|
# Here we check that explicitly passing continue_final_message=True overrides the default, so a chat
|
||||||
|
# that ends in a user message has that message continued rather than a new assistant message started
|
||||||
|
text_generator = pipeline(
|
||||||
|
task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt"
|
||||||
|
)
|
||||||
|
# Using `do_sample=False` to force deterministic output
|
||||||
|
chat1 = [
|
||||||
|
{"role": "system", "content": "This is a system message."},
|
||||||
|
{"role": "user", "content": "This is a test"},
|
||||||
|
]
|
||||||
|
outputs = text_generator(chat1, do_sample=False, max_new_tokens=10, continue_final_message=True)
|
||||||
|
|
||||||
|
# Assert that we continued the last message and there isn't a sneaky <|im_end|>
|
||||||
|
self.assertEqual(
|
||||||
|
outputs,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"generated_text": [
|
||||||
|
{"role": "system", "content": "This is a system message."},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "This is a test stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_small_chat_model_with_dataset_pt(self):
|
def test_small_chat_model_with_dataset_pt(self):
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
@@ -202,7 +262,6 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
|||||||
[
|
[
|
||||||
{"role": "system", "content": "This is a system message."},
|
{"role": "system", "content": "This is a system message."},
|
||||||
{"role": "user", "content": "This is a test"},
|
{"role": "user", "content": "This is a test"},
|
||||||
{"role": "assistant", "content": "This is a reply"},
|
|
||||||
],
|
],
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -213,7 +272,7 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
|||||||
return {"text": self.data[i]}
|
return {"text": self.data[i]}
|
||||||
|
|
||||||
text_generator = pipeline(
|
text_generator = pipeline(
|
||||||
task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="pt"
|
task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt"
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset = MyDataset()
|
dataset = MyDataset()
|
||||||
@@ -277,18 +336,16 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
|||||||
@require_tf
|
@require_tf
|
||||||
def test_small_chat_model_tf(self):
|
def test_small_chat_model_tf(self):
|
||||||
text_generator = pipeline(
|
text_generator = pipeline(
|
||||||
task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="tf"
|
task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="tf"
|
||||||
)
|
)
|
||||||
# Using `do_sample=False` to force deterministic output
|
# Using `do_sample=False` to force deterministic output
|
||||||
chat1 = [
|
chat1 = [
|
||||||
{"role": "system", "content": "This is a system message."},
|
{"role": "system", "content": "This is a system message."},
|
||||||
{"role": "user", "content": "This is a test"},
|
{"role": "user", "content": "This is a test"},
|
||||||
{"role": "assistant", "content": "This is a reply"},
|
|
||||||
]
|
]
|
||||||
chat2 = [
|
chat2 = [
|
||||||
{"role": "system", "content": "This is a system message."},
|
{"role": "system", "content": "This is a system message."},
|
||||||
{"role": "user", "content": "This is a second test"},
|
{"role": "user", "content": "This is a second test"},
|
||||||
{"role": "assistant", "content": "This is a reply"},
|
|
||||||
]
|
]
|
||||||
outputs = text_generator(chat1, do_sample=False, max_new_tokens=10)
|
outputs = text_generator(chat1, do_sample=False, max_new_tokens=10)
|
||||||
expected_chat1 = chat1 + [
|
expected_chat1 = chat1 + [
|
||||||
@@ -308,7 +365,7 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
|||||||
expected_chat2 = chat2 + [
|
expected_chat2 = chat2 + [
|
||||||
{
|
{
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": " factors factors factors factors factors factors factors factors factors factors",
|
"content": " stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -1327,6 +1327,36 @@ class TokenizerTesterMixin:
|
|||||||
[0] * (assistant_start2 - assistant_end - 1),
|
[0] * (assistant_start2 - assistant_end - 1),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@require_jinja
|
||||||
|
def test_continue_final_message(self):
|
||||||
|
dummy_template = """
|
||||||
|
{%- for message in messages %}
|
||||||
|
{{- "<|im_start|>" + message['role'] + "\n" + message['content'] + "<|im_end|>" + "\n"}}
|
||||||
|
{%- endfor %}"""
|
||||||
|
dummy_conversation = [
|
||||||
|
{"role": "system", "content": "system message"},
|
||||||
|
{"role": "user", "content": "user message"},
|
||||||
|
{"role": "assistant", "content": "assistant message"},
|
||||||
|
]
|
||||||
|
tokenizers = self.get_tokenizers()
|
||||||
|
for tokenizer in tokenizers:
|
||||||
|
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
||||||
|
output = tokenizer.apply_chat_template(
|
||||||
|
dummy_conversation, chat_template=dummy_template, tokenize=False, continue_final_message=False
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
output,
|
||||||
|
"<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message<|im_end|>\n",
|
||||||
|
)
|
||||||
|
prefill_output = tokenizer.apply_chat_template(
|
||||||
|
dummy_conversation, chat_template=dummy_template, tokenize=False, continue_final_message=True
|
||||||
|
)
|
||||||
|
# Assert that the final message is unterminated
|
||||||
|
self.assertEqual(
|
||||||
|
prefill_output,
|
||||||
|
"<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message",
|
||||||
|
)
|
||||||
|
|
||||||
@require_jinja
|
@require_jinja
|
||||||
def test_chat_template_dict(self):
|
def test_chat_template_dict(self):
|
||||||
dummy_template_1 = "{{'a'}}"
|
dummy_template_1 = "{{'a'}}"
|
||||||
|
|||||||