[ TokenizationUtils] Fix add_special_tokens when the token is already there (#28520)

* fix adding special tokens when the token is already there.

* add a test

* add a test

* nit

* fix the test: make sure the order is preserved

* Update tests/test_tokenization_common.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Arthur
2024-01-16 16:36:29 +01:00
committed by GitHub
parent 07ae53e6e7
commit 716df5fb7e
2 changed files with 30 additions and 4 deletions

View File

@@ -4139,3 +4139,28 @@ class TokenizerTesterMixin:
_test_added_vocab_and_eos(
EXPECTED_ADDED_TOKENS_DECODER, self.tokenizer_class, new_eos, tmp_dir_4
)
def test_special_token_addition(self):
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
# Create tokenizer and add an additional special token
tokenizer_1 = tokenizer.from_pretrained(pretrained_name)
tokenizer_1.add_special_tokens({"additional_special_tokens": ["<tok>"]})
self.assertEqual(tokenizer_1.additional_special_tokens, ["<tok>"])
with tempfile.TemporaryDirectory() as tmp_dir:
tokenizer_1.save_pretrained(tmp_dir)
# Load the above tokenizer and add the same special token a second time
tokenizer_2 = tokenizer.from_pretrained(pretrained_name)
tokenizer_2.add_special_tokens({"additional_special_tokens": ["<tok>"]})
self.assertEqual(tokenizer_2.additional_special_tokens, ["<tok>"])
tokenizer_2.add_special_tokens({"additional_special_tokens": ["<tok>", "<other>"]})
self.assertEqual(tokenizer_2.additional_special_tokens, ["<tok>", "<other>"])
tokenizer_2.add_special_tokens({"additional_special_tokens": ["<other>", "<another>"]})
self.assertEqual(tokenizer_2.additional_special_tokens, ["<other>", "<another>"])
tokenizer_2.add_special_tokens(
{"additional_special_tokens": ["<tok>"]},
replace_additional_special_tokens=False,
)
self.assertEqual(tokenizer_2.additional_special_tokens, ["<other>", "<another>", "<tok>"])