Fix add_special_tokens on fast tokenizers (#4531)

This commit is contained in:
Anthony MOI
2020-05-28 10:54:45 -04:00
committed by GitHub
parent e444648a30
commit 5e737018e1
2 changed files with 10 additions and 4 deletions

View File

@@ -2400,15 +2400,20 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
def add_special_tokens(self, special_tokens_dict: dict) -> int: def add_special_tokens(self, special_tokens_dict: dict) -> int:
# Map special tokens to class attributes (self.pad_token...) # Map special tokens to class attributes (self.pad_token...)
num_added_tokens = super().add_special_tokens(special_tokens_dict) super().add_special_tokens(special_tokens_dict)
# If the backend tokenizer the only specificities of special tokens are that # If the backend tokenizer the only specificities of special tokens are that
# - they will never be processed by the model, and # - they will never be processed by the model, and
# - they will be removed while decoding. # - they will be removed while decoding.
# But they are not mapped to special attributes in the backend so we can just # But they are not mapped to special attributes in the backend so we can just
# send a list. # send a list.
tokens = flatten(special_tokens_dict.values()) tokens = []
self._tokenizer.add_special_tokens(tokens) for token in special_tokens_dict.values():
if isinstance(token, list):
tokens += token
else:
tokens += [token]
num_added_tokens = self._tokenizer.add_special_tokens(tokens)
return num_added_tokens return num_added_tokens

View File

@@ -221,6 +221,7 @@ class CommonFastTokenizerTest(unittest.TestCase):
self.assertEqual(len(tokenizer_r), vocab_size + 3) self.assertEqual(len(tokenizer_r), vocab_size + 3)
self.assertEqual(tokenizer_r.add_special_tokens({}), 0) self.assertEqual(tokenizer_r.add_special_tokens({}), 0)
self.assertEqual(tokenizer_r.add_special_tokens({"bos_token": "[BOS]", "eos_token": "[EOS]"}), 2)
self.assertRaises( self.assertRaises(
AssertionError, tokenizer_r.add_special_tokens, {"additional_special_tokens": "<testtoken1>"} AssertionError, tokenizer_r.add_special_tokens, {"additional_special_tokens": "<testtoken1>"}
) )
@@ -228,7 +229,7 @@ class CommonFastTokenizerTest(unittest.TestCase):
self.assertEqual( self.assertEqual(
tokenizer_r.add_special_tokens({"additional_special_tokens": ["<testtoken3>", "<testtoken4>"]}), 2 tokenizer_r.add_special_tokens({"additional_special_tokens": ["<testtoken3>", "<testtoken4>"]}), 2
) )
self.assertEqual(len(tokenizer_r), vocab_size + 6) self.assertEqual(len(tokenizer_r), vocab_size + 8)
def assert_offsets_mapping(self, tokenizer_r): def assert_offsets_mapping(self, tokenizer_r):
text = "Wonderful no inspiration example with subtoken" text = "Wonderful no inspiration example with subtoken"