Consider do_lower_case in PreTrainedTokenizer
As pointed out in #1545, when an uncased model is used and a new uncased token is added, the tokenizer fails to recognize that token if the input text contains it in a different case. For instance, if we load bert-base-uncased into BertTokenizer and then use .add_tokens() to add "cool-token", we get the expected result for .tokenize('this is a cool-token'). However, .tokenize('this is a cOOl-Token') gives a possibly unexpected result: it is the same result the lowercase input produced before the new token was added.

This commit adds:
- functionality to PreTrainedTokenizer to handle this situation when a tokenizer (currently Bert, DistilBert, and XLNet) has the do_lower_case=True kwarg, by 1) lowercasing tokens added with .add_tokens() and 2) lowercasing text at the beginning of .tokenize()
- a new common test case for tokenizers

https://github.com/huggingface/transformers/issues/1545
This commit is contained in:
@@ -110,6 +110,36 @@ class CommonTestCases:
|
||||
|
||||
self.assertListEqual(subwords, subwords_loaded)
|
||||
|
||||
def test_added_tokens_do_lower_case(self):
    """Check how added tokens interact with the ``do_lower_case`` option.

    With ``do_lower_case=True`` the four candidate tokens (two lowercase,
    two uppercase spellings of the same strings) are expected to register
    as only 2 additions, and lowercase/uppercase inputs must tokenize to
    the same list.  With ``do_lower_case=False`` all 4 are expected to be
    added and the first token of the two tokenizations must differ.
    """
    lower_text = "aaaaa bbbbbb low cccccccccdddddddd l"
    upper_text = "AAAAA BBBBBB low CCCCCCCCCDDDDDDDD l"
    extra_tokens = [
        "aaaaa bbbbbb",
        "cccccccccdddddddd",
        "AAAAA BBBBBB",
        "CCCCCCCCCDDDDDDDD",
    ]

    # --- lowercasing tokenizer -------------------------------------------
    uncased = self.get_tokenizer(do_lower_case=True)

    # Tokenization of the lowercase text before any token is added.
    baseline = uncased.tokenize(lower_text)

    num_added = uncased.add_tokens(extra_tokens)
    self.assertEqual(num_added, 2)

    from_lower = uncased.tokenize(lower_text)
    from_upper = uncased.tokenize(upper_text)

    self.assertEqual(len(from_lower), len(from_upper))
    # The baseline is expected to be the longer one; only inequality is checked.
    self.assertNotEqual(len(from_lower), len(baseline))
    self.assertListEqual(from_lower, from_upper)

    # --- case-preserving tokenizer ---------------------------------------
    cased = self.get_tokenizer(do_lower_case=False)

    num_added = cased.add_tokens(extra_tokens)
    self.assertEqual(num_added, 4)

    from_lower = cased.tokenize(lower_text)
    from_upper = cased.tokenize(upper_text)

    # Lengths should still match ...
    self.assertEqual(len(from_lower), len(from_upper))
    # NOTE: `baseline` still comes from the lowercasing tokenizer above.
    self.assertNotEqual(len(from_lower), len(baseline))
    # ... but at least the first tokens should differ.
    self.assertNotEqual(from_lower[0], from_upper[0])
|
||||
|
||||
def test_add_tokens_tokenizer(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
@@ -160,7 +190,6 @@ class CommonTestCases:
|
||||
self.assertEqual(tokens[0], tokenizer.eos_token_id)
|
||||
self.assertEqual(tokens[-2], tokenizer.pad_token_id)
|
||||
|
||||
|
||||
def test_required_methods_tokenizer(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
input_text, output_text = self.get_input_output_texts()
|
||||
|
||||
Reference in New Issue
Block a user