[ TokenizationUtils] Fix add_special_tokens when the token is already there (#28520)

* fix adding special tokens when the token is already there.

* add a test

* add a test

* nit

* fix the test: make sure the order is preserved

* Update tests/test_tokenization_common.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Arthur
2024-01-16 16:36:29 +01:00
committed by GitHub
parent 07ae53e6e7
commit 716df5fb7e
2 changed files with 30 additions and 4 deletions

View File

@@ -943,14 +943,15 @@ class SpecialTokensMixin:
isinstance(t, (str, AddedToken)) for t in value isinstance(t, (str, AddedToken)) for t in value
), f"Tokens {value} for key {key} should all be str or AddedToken instances" ), f"Tokens {value} for key {key} should all be str or AddedToken instances"
to_add = set() to_add = []
for token in value: for token in value:
if isinstance(token, str): if isinstance(token, str):
# for legacy purpose we default to stripping. `test_add_tokens_tokenizer` depends on this # for legacy purpose we default to stripping. `test_add_tokens_tokenizer` depends on this
token = AddedToken(token, rstrip=False, lstrip=False, normalized=False, special=True) token = AddedToken(token, rstrip=False, lstrip=False, normalized=False, special=True)
if str(token) not in self.additional_special_tokens: if not replace_additional_special_tokens and str(token) in self.additional_special_tokens:
to_add.add(token) continue
if replace_additional_special_tokens: to_add.append(token)
if replace_additional_special_tokens and len(to_add) > 0:
setattr(self, key, list(to_add)) setattr(self, key, list(to_add))
else: else:
self._additional_special_tokens.extend(to_add) self._additional_special_tokens.extend(to_add)

View File

@@ -4139,3 +4139,28 @@ class TokenizerTesterMixin:
_test_added_vocab_and_eos( _test_added_vocab_and_eos(
EXPECTED_ADDED_TOKENS_DECODER, self.tokenizer_class, new_eos, tmp_dir_4 EXPECTED_ADDED_TOKENS_DECODER, self.tokenizer_class, new_eos, tmp_dir_4
) )
def test_special_token_addition(self):
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
# Create tokenizer and add an additional special token
tokenizer_1 = tokenizer.from_pretrained(pretrained_name)
tokenizer_1.add_special_tokens({"additional_special_tokens": ["<tok>"]})
self.assertEqual(tokenizer_1.additional_special_tokens, ["<tok>"])
with tempfile.TemporaryDirectory() as tmp_dir:
tokenizer_1.save_pretrained(tmp_dir)
# Load the above tokenizer and add the same special token a second time
tokenizer_2 = tokenizer.from_pretrained(pretrained_name)
tokenizer_2.add_special_tokens({"additional_special_tokens": ["<tok>"]})
self.assertEqual(tokenizer_2.additional_special_tokens, ["<tok>"])
tokenizer_2.add_special_tokens({"additional_special_tokens": ["<tok>", "<other>"]})
self.assertEqual(tokenizer_2.additional_special_tokens, ["<tok>", "<other>"])
tokenizer_2.add_special_tokens({"additional_special_tokens": ["<other>", "<another>"]})
self.assertEqual(tokenizer_2.additional_special_tokens, ["<other>", "<another>"])
tokenizer_2.add_special_tokens(
{"additional_special_tokens": ["<tok>"]},
replace_additional_special_tokens=False,
)
self.assertEqual(tokenizer_2.additional_special_tokens, ["<other>", "<another>", "<tok>"])