Fix add_special_tokens on fast tokenizers (#4531)
This commit is contained in:
@@ -2400,15 +2400,20 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
|
|||||||
|
|
||||||
def add_special_tokens(self, special_tokens_dict: dict) -> int:
|
def add_special_tokens(self, special_tokens_dict: dict) -> int:
|
||||||
# Map special tokens to class attributes (self.pad_token...)
|
# Map special tokens to class attributes (self.pad_token...)
|
||||||
num_added_tokens = super().add_special_tokens(special_tokens_dict)
|
super().add_special_tokens(special_tokens_dict)
|
||||||
|
|
||||||
# In the backend tokenizer, the only specificities of special tokens are that
|
# In the backend tokenizer, the only specificities of special tokens are that
|
||||||
# - they will never be processed by the model, and
|
# - they will never be processed by the model, and
|
||||||
# - they will be removed while decoding.
|
# - they will be removed while decoding.
|
||||||
# But they are not mapped to special attributes in the backend so we can just
|
# But they are not mapped to special attributes in the backend so we can just
|
||||||
# send a list.
|
# send a list.
|
||||||
tokens = flatten(special_tokens_dict.values())
|
tokens = []
|
||||||
self._tokenizer.add_special_tokens(tokens)
|
for token in special_tokens_dict.values():
|
||||||
|
if isinstance(token, list):
|
||||||
|
tokens += token
|
||||||
|
else:
|
||||||
|
tokens += [token]
|
||||||
|
num_added_tokens = self._tokenizer.add_special_tokens(tokens)
|
||||||
|
|
||||||
return num_added_tokens
|
return num_added_tokens
|
||||||
|
|
||||||
|
|||||||
@@ -221,6 +221,7 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
|||||||
self.assertEqual(len(tokenizer_r), vocab_size + 3)
|
self.assertEqual(len(tokenizer_r), vocab_size + 3)
|
||||||
|
|
||||||
self.assertEqual(tokenizer_r.add_special_tokens({}), 0)
|
self.assertEqual(tokenizer_r.add_special_tokens({}), 0)
|
||||||
|
self.assertEqual(tokenizer_r.add_special_tokens({"bos_token": "[BOS]", "eos_token": "[EOS]"}), 2)
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
AssertionError, tokenizer_r.add_special_tokens, {"additional_special_tokens": "<testtoken1>"}
|
AssertionError, tokenizer_r.add_special_tokens, {"additional_special_tokens": "<testtoken1>"}
|
||||||
)
|
)
|
||||||
@@ -228,7 +229,7 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
|||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
tokenizer_r.add_special_tokens({"additional_special_tokens": ["<testtoken3>", "<testtoken4>"]}), 2
|
tokenizer_r.add_special_tokens({"additional_special_tokens": ["<testtoken3>", "<testtoken4>"]}), 2
|
||||||
)
|
)
|
||||||
self.assertEqual(len(tokenizer_r), vocab_size + 6)
|
self.assertEqual(len(tokenizer_r), vocab_size + 8)
|
||||||
|
|
||||||
def assert_offsets_mapping(self, tokenizer_r):
|
def assert_offsets_mapping(self, tokenizer_r):
|
||||||
text = "Wonderful no inspiration example with subtoken"
|
text = "Wonderful no inspiration example with subtoken"
|
||||||
|
|||||||
Reference in New Issue
Block a user