[ TokenizationUtils] Fix add_special_tokens when the token is already there (#28520)
* fix adding special tokens when the token is already there. * add a test * add a test * nit * fix the test: make sure the order is preserved * Update tests/test_tokenization_common.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -943,14 +943,15 @@ class SpecialTokensMixin:
|
||||
isinstance(t, (str, AddedToken)) for t in value
|
||||
), f"Tokens {value} for key {key} should all be str or AddedToken instances"
|
||||
|
||||
to_add = set()
|
||||
to_add = []
|
||||
for token in value:
|
||||
if isinstance(token, str):
|
||||
# for legacy purpose we default to stripping. `test_add_tokens_tokenizer` depends on this
|
||||
token = AddedToken(token, rstrip=False, lstrip=False, normalized=False, special=True)
|
||||
if str(token) not in self.additional_special_tokens:
|
||||
to_add.add(token)
|
||||
if replace_additional_special_tokens:
|
||||
if not replace_additional_special_tokens and str(token) in self.additional_special_tokens:
|
||||
continue
|
||||
to_add.append(token)
|
||||
if replace_additional_special_tokens and len(to_add) > 0:
|
||||
setattr(self, key, list(to_add))
|
||||
else:
|
||||
self._additional_special_tokens.extend(to_add)
|
||||
|
||||
@@ -4139,3 +4139,28 @@ class TokenizerTesterMixin:
|
||||
_test_added_vocab_and_eos(
|
||||
EXPECTED_ADDED_TOKENS_DECODER, self.tokenizer_class, new_eos, tmp_dir_4
|
||||
)
|
||||
|
||||
def test_special_token_addition(self):
|
||||
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
|
||||
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
|
||||
# Create tokenizer and add an additional special token
|
||||
tokenizer_1 = tokenizer.from_pretrained(pretrained_name)
|
||||
tokenizer_1.add_special_tokens({"additional_special_tokens": ["<tok>"]})
|
||||
self.assertEqual(tokenizer_1.additional_special_tokens, ["<tok>"])
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tokenizer_1.save_pretrained(tmp_dir)
|
||||
# Load the above tokenizer and add the same special token a second time
|
||||
tokenizer_2 = tokenizer.from_pretrained(pretrained_name)
|
||||
tokenizer_2.add_special_tokens({"additional_special_tokens": ["<tok>"]})
|
||||
self.assertEqual(tokenizer_2.additional_special_tokens, ["<tok>"])
|
||||
|
||||
tokenizer_2.add_special_tokens({"additional_special_tokens": ["<tok>", "<other>"]})
|
||||
self.assertEqual(tokenizer_2.additional_special_tokens, ["<tok>", "<other>"])
|
||||
tokenizer_2.add_special_tokens({"additional_special_tokens": ["<other>", "<another>"]})
|
||||
self.assertEqual(tokenizer_2.additional_special_tokens, ["<other>", "<another>"])
|
||||
|
||||
tokenizer_2.add_special_tokens(
|
||||
{"additional_special_tokens": ["<tok>"]},
|
||||
replace_additional_special_tokens=False,
|
||||
)
|
||||
self.assertEqual(tokenizer_2.additional_special_tokens, ["<other>", "<another>", "<tok>"])
|
||||
|
||||
Reference in New Issue
Block a user