apply_chat_template: consistent behaviour for return_assistant_tokens_mask=True return_tensors=True (#35582)

* apply_chat_template: consistent return_tensors behaviour with return_assistant_tokens_mask flag

* test_chat_template_return_assistant_tokens_mask: support tokenizers with no attention mask

* test_chat_template_return_assistant_tokens_mask: skip tokenizers with no padding token

* test_chat_template_return_assistant_tokens_mask: force tokenizer padding_side=right

---------

Co-authored-by: Eduard Allakhverdov <goncharova@airi.net>
Co-authored-by: d.tarasov <d.tarasov@airi.net>
This commit is contained in:
Dmitry Tarasov
2025-02-04 12:27:52 +03:00
committed by GitHub
parent 9c02cb6233
commit 2ba040a71f
2 changed files with 61 additions and 1 deletions

View File

@@ -62,6 +62,7 @@ from transformers.tokenization_utils import AddedToken
if is_torch_available():
import torch
import torch.nn as nn
@@ -1219,6 +1220,7 @@ class TokenizerTesterMixin:
self.assertEqual(len(strftime_output), 10)
self.assertEqual(len(strftime_output.split("-")), 3)
@require_torch
@require_jinja
def test_chat_template_return_assistant_tokens_mask(self):
dummy_template = (
@@ -1263,6 +1265,9 @@ class TokenizerTesterMixin:
self.skipTest(reason="No fast tokenizer defined")
tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name)
self._check_no_pad_token_padding(tokenizer_r, conversations)
tokenizer_r.padding_side = "right"
# check batched
output = tokenizer_r.apply_chat_template(
@@ -1272,6 +1277,20 @@ class TokenizerTesterMixin:
return_assistant_tokens_mask=True,
return_dict=True,
)
output_pt = tokenizer_r.apply_chat_template(
conversations,
chat_template=dummy_template,
tokenize=True,
padding=True,
return_assistant_tokens_mask=True,
return_dict=True,
return_tensors="pt",
)
self.assertEqual(type(output_pt["assistant_masks"]), torch.Tensor)
self.assertEqual(output_pt["assistant_masks"].shape, output_pt["input_ids"].shape)
for i, conv in enumerate(conversations):
chat_string = tokenizer_r.apply_chat_template(
conversations[i], tokenize=False, chat_template=dummy_template
@@ -1297,18 +1316,30 @@ class TokenizerTesterMixin:
output["assistant_masks"][i][assistant_start : assistant_end + 1],
[1] * (assistant_end - assistant_start + 1),
)
self.assertTrue(
(output_pt["assistant_masks"][i, assistant_start : assistant_end + 1] == 1).all(),
)
# assert 1 second assistant message
self.assertEqual(
output["assistant_masks"][i][assistant_start2 : assistant_end2 + 1],
[1] * (assistant_end2 - assistant_start2 + 1),
)
self.assertTrue(
(output_pt["assistant_masks"][i, assistant_start2 : assistant_end2 + 1] == 1).all(),
)
# assert 0 in user/system indices
self.assertEqual(output["assistant_masks"][i][:assistant_start], [0] * assistant_start)
self.assertTrue((output_pt["assistant_masks"][i, :assistant_start] == 0).all())
self.assertEqual(
output["assistant_masks"][i][assistant_end + 1 : assistant_start2],
[0] * (assistant_start2 - assistant_end - 1),
)
self.assertTrue(
(output_pt["assistant_masks"][i, assistant_end + 1 : assistant_start2] == 0).all(),
)
# check not batched
output = tokenizer_r.apply_chat_template(
@@ -1318,6 +1349,17 @@ class TokenizerTesterMixin:
return_assistant_tokens_mask=True,
return_dict=True,
)
output_pt = tokenizer_r.apply_chat_template(
conversations[0],
chat_template=dummy_template,
tokenize=True,
return_assistant_tokens_mask=True,
return_dict=True,
return_tensors="pt",
)
self.assertEqual(type(output_pt["assistant_masks"]), torch.Tensor)
self.assertEqual(output_pt["assistant_masks"].shape, output_pt["input_ids"].shape)
chat_string = tokenizer_r.apply_chat_template(
conversations[0], tokenize=False, chat_template=dummy_template
@@ -1336,17 +1378,27 @@ class TokenizerTesterMixin:
output["assistant_masks"][assistant_start : assistant_end + 1],
[1] * (assistant_end - assistant_start + 1),
)
self.assertTrue(
(output_pt["assistant_masks"][assistant_start : assistant_end + 1] == 1).all(),
)
self.assertEqual(
output["assistant_masks"][assistant_start2 : assistant_end2 + 1],
[1] * (assistant_end2 - assistant_start2 + 1),
)
self.assertTrue(
(output_pt["assistant_masks"][assistant_start2 : assistant_end2 + 1] == 1).all(),
)
# assert 0 in user/system indices
self.assertEqual(output["assistant_masks"][:assistant_start], [0] * assistant_start)
self.assertTrue((output_pt["assistant_masks"][0, :assistant_start] == 0).all())
self.assertEqual(
output["assistant_masks"][assistant_end + 1 : assistant_start2],
[0] * (assistant_start2 - assistant_end - 1),
)
self.assertTrue(
(output_pt["assistant_masks"][0, assistant_end + 1 : assistant_start2] == 0).all(),
)
@require_jinja
def test_chat_template_return_assistant_tokens_mask_truncated(self):