From 2670b0d682746e1fe94ab9c7b4d2fd7f4af03193 Mon Sep 17 00:00:00 2001
From: Michael Watkins
Date: Wed, 4 Dec 2019 17:53:25 +0200
Subject: [PATCH] Fix bug which lowercases special tokens

When do_lower_case is set, tokenize() lowercased the whole input text and
add_tokens() lowercased every new token, so special tokens were lowercased
as well. Lowercase only the text outside special tokens.

The special-token pattern is an unanchored alternation, so every special
token is preserved wherever it occurs in the text, not only at the start.
When no special tokens are defined, the whole text is lowercased.

---
 transformers/tests/tokenization_tests_commons.py |  8 +++++---
 transformers/tokenization_utils.py               | 18 ++++++++++++++++--
 2 files changed, 21 insertions(+), 5 deletions(-)

diff --git a/transformers/tests/tokenization_tests_commons.py b/transformers/tests/tokenization_tests_commons.py
index faff003f4b..d904f0067e 100644
--- a/transformers/tests/tokenization_tests_commons.py
+++ b/transformers/tests/tokenization_tests_commons.py
@@ -115,8 +115,10 @@ class CommonTestCases:
         def test_added_tokens_do_lower_case(self):
             tokenizer = self.get_tokenizer(do_lower_case=True)
 
-            text = "aaaaa bbbbbb low cccccccccdddddddd l"
-            text2 = "AAAAA BBBBBB low CCCCCCCCCDDDDDDDD l"
+            special_token = tokenizer.all_special_tokens[0]
+
+            text = special_token + " aaaaa bbbbbb low cccccccccdddddddd l " + special_token
+            text2 = special_token + " AAAAA BBBBBB low CCCCCCCCCDDDDDDDD l " + special_token
 
             toks0 = tokenizer.tokenize(text)  # toks before adding new_toks
 
@@ -141,7 +143,7 @@ class CommonTestCases:
 
             self.assertEqual(len(toks), len(toks2))  # Length should still be the same
             self.assertNotEqual(len(toks), len(toks0))
-            self.assertNotEqual(toks[0], toks2[0])  # But at least the first tokens should differ
+            self.assertNotEqual(toks[1], toks2[1])  # But at least the first non-special tokens should differ
 
         def test_add_tokens_tokenizer(self):
             tokenizer = self.get_tokenizer()
diff --git a/transformers/tokenization_utils.py b/transformers/tokenization_utils.py
index 4c6cbd8986..eb22c50ebd 100644
--- a/transformers/tokenization_utils.py
+++ b/transformers/tokenization_utils.py
@@ -22,6 +22,7 @@ import json
 import six
 import copy
 import itertools
+import re
 from io import open
 
 from .file_utils import cached_path, is_tf_available, is_torch_available
@@ -520,7 +521,7 @@ class PreTrainedTokenizer(object):
         to_add_tokens = []
         for token in new_tokens:
             assert isinstance(token, str) or (six.PY2 and isinstance(token, unicode))
-            if self.init_kwargs.get('do_lower_case', False):
+            if self.init_kwargs.get('do_lower_case', False) and token not in self.all_special_tokens:
                 token = token.lower()
             if token != self.unk_token and \
                     self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token) and \
@@ -615,8 +616,21 @@ class PreTrainedTokenizer(object):
 
             Take care of added tokens.
         """
+        def lowercase_text(t):
+            # convert non-special tokens to lowercase
+            escaped_special_toks = [re.escape(s_tok) for s_tok in self.all_special_tokens]
+            if not escaped_special_toks:
+                # nothing to protect; an empty alternation would match '' at every position
+                return t.lower()
+            pattern = r'(' + r'|'.join(escaped_special_toks) + r')|' + \
+                      r'(.+?)'
+            return re.sub(
+                pattern,
+                lambda m: m.groups()[0] or m.groups()[1].lower(),
+                t)
+
         if self.init_kwargs.get('do_lower_case', False):
-            text = text.lower()
+            text = lowercase_text(text)
 
         def split_on_token(tok, text):
             result = []