From 716df5fb7ec8b24b0332442e0fbf30ea7526e569 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Tue, 16 Jan 2024 16:36:29 +0100 Subject: [PATCH] [ `TokenizationUtils`] Fix `add_special_tokens` when the token is already there (#28520) * fix adding special tokens when the token is already there. * add a test * add a test * nit * fix the test: make sure the order is preserved * Update tests/test_tokenization_common.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/tokenization_utils_base.py | 9 ++++---- tests/test_tokenization_common.py | 25 +++++++++++++++++++++ 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 11be7aac2d..b7377ea314 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -943,14 +943,15 @@ class SpecialTokensMixin: isinstance(t, (str, AddedToken)) for t in value ), f"Tokens {value} for key {key} should all be str or AddedToken instances" - to_add = set() + to_add = [] for token in value: if isinstance(token, str): # for legacy purpose we default to stripping. 
`test_add_tokens_tokenizer` depends on this token = AddedToken(token, rstrip=False, lstrip=False, normalized=False, special=True) - if str(token) not in self.additional_special_tokens: - to_add.add(token) - if replace_additional_special_tokens: + if not replace_additional_special_tokens and str(token) in self.additional_special_tokens: + continue + to_add.append(token) + if replace_additional_special_tokens and len(to_add) > 0: setattr(self, key, list(to_add)) else: self._additional_special_tokens.extend(to_add) diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 9b60b2f186..e5b9a34702 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -4139,3 +4139,28 @@ class TokenizerTesterMixin: _test_added_vocab_and_eos( EXPECTED_ADDED_TOKENS_DECODER, self.tokenizer_class, new_eos, tmp_dir_4 ) + + def test_special_token_addition(self): + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): + # Create tokenizer and add an additional special token + tokenizer_1 = tokenizer.from_pretrained(pretrained_name) + tokenizer_1.add_special_tokens({"additional_special_tokens": ["<tok>"]}) + self.assertEqual(tokenizer_1.additional_special_tokens, ["<tok>"]) + with tempfile.TemporaryDirectory() as tmp_dir: + tokenizer_1.save_pretrained(tmp_dir) + # Load the above tokenizer and add the same special token a second time + tokenizer_2 = tokenizer.from_pretrained(pretrained_name) + tokenizer_2.add_special_tokens({"additional_special_tokens": ["<tok>"]}) + self.assertEqual(tokenizer_2.additional_special_tokens, ["<tok>"]) + + tokenizer_2.add_special_tokens({"additional_special_tokens": ["<tok>", "<other>"]}) + self.assertEqual(tokenizer_2.additional_special_tokens, ["<tok>", "<other>"]) + tokenizer_2.add_special_tokens({"additional_special_tokens": ["<other>", "<another>"]}) + self.assertEqual(tokenizer_2.additional_special_tokens, ["<other>", "<another>"]) + + tokenizer_2.add_special_tokens( + 
{"additional_special_tokens": ["<tok>"]}, + replace_additional_special_tokens=False, + ) + self.assertEqual(tokenizer_2.additional_special_tokens, ["<other>", "<another>", "<tok>"])