[TokenizationUtils] Fix add_special_tokens when the token is already there (#28520)
* fix adding special tokens when the token is already there.
* add a test
* add a test
* nit
* fix the test: make sure the order is preserved
* Update tests/test_tokenization_common.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -943,14 +943,15 @@ class SpecialTokensMixin:
|
|||||||
isinstance(t, (str, AddedToken)) for t in value
|
isinstance(t, (str, AddedToken)) for t in value
|
||||||
), f"Tokens {value} for key {key} should all be str or AddedToken instances"
|
), f"Tokens {value} for key {key} should all be str or AddedToken instances"
|
||||||
|
|
||||||
to_add = set()
|
to_add = []
|
||||||
for token in value:
|
for token in value:
|
||||||
if isinstance(token, str):
|
if isinstance(token, str):
|
||||||
# for legacy purpose we default to stripping. `test_add_tokens_tokenizer` depends on this
|
# for legacy purpose we default to stripping. `test_add_tokens_tokenizer` depends on this
|
||||||
token = AddedToken(token, rstrip=False, lstrip=False, normalized=False, special=True)
|
token = AddedToken(token, rstrip=False, lstrip=False, normalized=False, special=True)
|
||||||
if str(token) not in self.additional_special_tokens:
|
if not replace_additional_special_tokens and str(token) in self.additional_special_tokens:
|
||||||
to_add.add(token)
|
continue
|
||||||
if replace_additional_special_tokens:
|
to_add.append(token)
|
||||||
|
if replace_additional_special_tokens and len(to_add) > 0:
|
||||||
setattr(self, key, list(to_add))
|
setattr(self, key, list(to_add))
|
||||||
else:
|
else:
|
||||||
self._additional_special_tokens.extend(to_add)
|
self._additional_special_tokens.extend(to_add)
|
||||||
|
|||||||
@@ -4139,3 +4139,28 @@ class TokenizerTesterMixin:
|
|||||||
_test_added_vocab_and_eos(
|
_test_added_vocab_and_eos(
|
||||||
EXPECTED_ADDED_TOKENS_DECODER, self.tokenizer_class, new_eos, tmp_dir_4
|
EXPECTED_ADDED_TOKENS_DECODER, self.tokenizer_class, new_eos, tmp_dir_4
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_special_token_addition(self):
|
||||||
|
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
|
||||||
|
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
|
||||||
|
# Create tokenizer and add an additional special token
|
||||||
|
tokenizer_1 = tokenizer.from_pretrained(pretrained_name)
|
||||||
|
tokenizer_1.add_special_tokens({"additional_special_tokens": ["<tok>"]})
|
||||||
|
self.assertEqual(tokenizer_1.additional_special_tokens, ["<tok>"])
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
tokenizer_1.save_pretrained(tmp_dir)
|
||||||
|
# Load the above tokenizer and add the same special token a second time
|
||||||
|
tokenizer_2 = tokenizer.from_pretrained(pretrained_name)
|
||||||
|
tokenizer_2.add_special_tokens({"additional_special_tokens": ["<tok>"]})
|
||||||
|
self.assertEqual(tokenizer_2.additional_special_tokens, ["<tok>"])
|
||||||
|
|
||||||
|
tokenizer_2.add_special_tokens({"additional_special_tokens": ["<tok>", "<other>"]})
|
||||||
|
self.assertEqual(tokenizer_2.additional_special_tokens, ["<tok>", "<other>"])
|
||||||
|
tokenizer_2.add_special_tokens({"additional_special_tokens": ["<other>", "<another>"]})
|
||||||
|
self.assertEqual(tokenizer_2.additional_special_tokens, ["<other>", "<another>"])
|
||||||
|
|
||||||
|
tokenizer_2.add_special_tokens(
|
||||||
|
{"additional_special_tokens": ["<tok>"]},
|
||||||
|
replace_additional_special_tokens=False,
|
||||||
|
)
|
||||||
|
self.assertEqual(tokenizer_2.additional_special_tokens, ["<other>", "<another>", "<tok>"])
|
||||||
|
|||||||
Reference in New Issue
Block a user