From 7246d3c2f93c4461f3ec8ada7a26a002d8f196ea Mon Sep 17 00:00:00 2001 From: Michael Watkins Date: Wed, 6 Nov 2019 13:18:16 +0200 Subject: [PATCH] Consider do_lower_case in PreTrainedTokenizer As pointed out in #1545, when using an uncased model, and adding a new uncased token, the tokenizer does not correctly identify this in the case that the input text contains the token in a cased format. For instance, if we load bert-base-uncased into BertTokenizer, and then use .add_tokens() to add "cool-token", we get the expected result for .tokenize('this is a cool-token'). However, we get a possibly unexpected result for .tokenize('this is a cOOl-Token'), which in fact mirrors the result for the former from before the new token was added. This commit adds - functionality to PreTrainedTokenizer to handle this situation in case a tokenizer (currently Bert, DistilBert, and XLNet) has the do_lower_case=True kwarg by: 1) lowercasing tokens added with .add_tokens() 2) lowercasing text at the beginning of .tokenize() - new common test case for tokenizers https://github.com/huggingface/transformers/issues/1545 --- .../tests/tokenization_tests_commons.py | 31 ++++++++++++++++++- transformers/tokenization_utils.py | 5 +++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/transformers/tests/tokenization_tests_commons.py b/transformers/tests/tokenization_tests_commons.py index a921696b77..287e6fc7b3 100644 --- a/transformers/tests/tokenization_tests_commons.py +++ b/transformers/tests/tokenization_tests_commons.py @@ -110,6 +110,36 @@ class CommonTestCases: self.assertListEqual(subwords, subwords_loaded) + def test_added_tokens_do_lower_case(self): + tokenizer = self.get_tokenizer(do_lower_case=True) + + text = "aaaaa bbbbbb low cccccccccdddddddd l" + text2 = "AAAAA BBBBBB low CCCCCCCCCDDDDDDDD l" + + toks0 = tokenizer.tokenize(text) # toks before adding new_toks + + new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd", 'AAAAA BBBBBB', 'CCCCCCCCCDDDDDDDD'] + added = tokenizer.add_tokens(new_toks) + self.assertEqual(added, 2) + + toks = tokenizer.tokenize(text) + toks2 = tokenizer.tokenize(text2) + + self.assertEqual(len(toks), len(toks2)) + self.assertNotEqual(len(toks), len(toks0)) # toks0 should be longer + self.assertListEqual(toks, toks2) + + tokenizer = self.get_tokenizer(do_lower_case=False) + + added = tokenizer.add_tokens(new_toks) + self.assertEqual(added, 4) + + toks = tokenizer.tokenize(text) + toks2 = tokenizer.tokenize(text2) + + self.assertEqual(len(toks), len(toks2)) # Length should still be the same + self.assertNotEqual(len(toks), len(toks0)) + self.assertNotEqual(toks[0], toks2[0]) # But at least the first tokens should differ def test_add_tokens_tokenizer(self): tokenizer = self.get_tokenizer() @@ -160,7 +190,6 @@ class CommonTestCases: self.assertEqual(tokens[0], tokenizer.eos_token_id) self.assertEqual(tokens[-2], tokenizer.pad_token_id) - def test_required_methods_tokenizer(self): tokenizer = self.get_tokenizer() input_text, output_text = self.get_input_output_texts() diff --git a/transformers/tokenization_utils.py b/transformers/tokenization_utils.py index cd14cc4582..fc31c10d25 100644 --- a/transformers/tokenization_utils.py +++ b/transformers/tokenization_utils.py @@ -512,6 +512,8 @@ class PreTrainedTokenizer(object): to_add_tokens = [] for token in new_tokens: assert isinstance(token, str) or (six.PY2 and isinstance(token, unicode)) + if self.init_kwargs.get('do_lower_case', False): + token = token.lower() if token != self.unk_token and \ self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token) and \ token not in to_add_tokens: @@ -605,6 +607,9 @@ class PreTrainedTokenizer(object): Take care of added tokens. """ + if self.init_kwargs.get('do_lower_case', False): + text = text.lower() + def split_on_token(tok, text): result = [] split_text = text.split(tok)