[ TokenizationUtils] Fix add_special_tokens when the token is already there (#28520)
* fix adding special tokens when the token is already there. * add a test * add a test * nit * fix the test: make sure the order is preserved * Update tests/test_tokenization_common.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -4139,3 +4139,28 @@ class TokenizerTesterMixin:
|
||||
_test_added_vocab_and_eos(
|
||||
EXPECTED_ADDED_TOKENS_DECODER, self.tokenizer_class, new_eos, tmp_dir_4
|
||||
)
|
||||
|
||||
def test_special_token_addition(self):
|
||||
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
|
||||
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
|
||||
# Create tokenizer and add an additional special token
|
||||
tokenizer_1 = tokenizer.from_pretrained(pretrained_name)
|
||||
tokenizer_1.add_special_tokens({"additional_special_tokens": ["<tok>"]})
|
||||
self.assertEqual(tokenizer_1.additional_special_tokens, ["<tok>"])
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tokenizer_1.save_pretrained(tmp_dir)
|
||||
# Load the above tokenizer and add the same special token a second time
|
||||
tokenizer_2 = tokenizer.from_pretrained(pretrained_name)
|
||||
tokenizer_2.add_special_tokens({"additional_special_tokens": ["<tok>"]})
|
||||
self.assertEqual(tokenizer_2.additional_special_tokens, ["<tok>"])
|
||||
|
||||
tokenizer_2.add_special_tokens({"additional_special_tokens": ["<tok>", "<other>"]})
|
||||
self.assertEqual(tokenizer_2.additional_special_tokens, ["<tok>", "<other>"])
|
||||
tokenizer_2.add_special_tokens({"additional_special_tokens": ["<other>", "<another>"]})
|
||||
self.assertEqual(tokenizer_2.additional_special_tokens, ["<other>", "<another>"])
|
||||
|
||||
tokenizer_2.add_special_tokens(
|
||||
{"additional_special_tokens": ["<tok>"]},
|
||||
replace_additional_special_tokens=False,
|
||||
)
|
||||
self.assertEqual(tokenizer_2.additional_special_tokens, ["<other>", "<another>", "<tok>"])
|
||||
|
||||
Reference in New Issue
Block a user