Fix bug which lowercases special tokens
This commit is contained in:
committed by
Lysandre Debut
parent
35401fe50f
commit
2670b0d682
@@ -115,8 +115,10 @@ class CommonTestCases:
|
|||||||
def test_added_tokens_do_lower_case(self):
|
def test_added_tokens_do_lower_case(self):
|
||||||
tokenizer = self.get_tokenizer(do_lower_case=True)
|
tokenizer = self.get_tokenizer(do_lower_case=True)
|
||||||
|
|
||||||
text = "aaaaa bbbbbb low cccccccccdddddddd l"
|
special_token = tokenizer.all_special_tokens[0]
|
||||||
text2 = "AAAAA BBBBBB low CCCCCCCCCDDDDDDDD l"
|
|
||||||
|
text = special_token + " aaaaa bbbbbb low cccccccccdddddddd l " + special_token
|
||||||
|
text2 = special_token + " AAAAA BBBBBB low CCCCCCCCCDDDDDDDD l " + special_token
|
||||||
|
|
||||||
toks0 = tokenizer.tokenize(text) # toks before adding new_toks
|
toks0 = tokenizer.tokenize(text) # toks before adding new_toks
|
||||||
|
|
||||||
@@ -141,7 +143,7 @@ class CommonTestCases:
|
|||||||
|
|
||||||
self.assertEqual(len(toks), len(toks2)) # Length should still be the same
|
self.assertEqual(len(toks), len(toks2)) # Length should still be the same
|
||||||
self.assertNotEqual(len(toks), len(toks0))
|
self.assertNotEqual(len(toks), len(toks0))
|
||||||
self.assertNotEqual(toks[0], toks2[0]) # But at least the first tokens should differ
|
self.assertNotEqual(toks[1], toks2[1]) # But at least the first non-special tokens should differ
|
||||||
|
|
||||||
def test_add_tokens_tokenizer(self):
|
def test_add_tokens_tokenizer(self):
|
||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import json
|
|||||||
import six
|
import six
|
||||||
import copy
|
import copy
|
||||||
import itertools
|
import itertools
|
||||||
|
import re
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
from .file_utils import cached_path, is_tf_available, is_torch_available
|
from .file_utils import cached_path, is_tf_available, is_torch_available
|
||||||
@@ -520,7 +521,7 @@ class PreTrainedTokenizer(object):
|
|||||||
to_add_tokens = []
|
to_add_tokens = []
|
||||||
for token in new_tokens:
|
for token in new_tokens:
|
||||||
assert isinstance(token, str) or (six.PY2 and isinstance(token, unicode))
|
assert isinstance(token, str) or (six.PY2 and isinstance(token, unicode))
|
||||||
if self.init_kwargs.get('do_lower_case', False):
|
if self.init_kwargs.get('do_lower_case', False) and token not in self.all_special_tokens:
|
||||||
token = token.lower()
|
token = token.lower()
|
||||||
if token != self.unk_token and \
|
if token != self.unk_token and \
|
||||||
self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token) and \
|
self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token) and \
|
||||||
@@ -615,8 +616,18 @@ class PreTrainedTokenizer(object):
|
|||||||
|
|
||||||
Take care of added tokens.
|
Take care of added tokens.
|
||||||
"""
|
"""
|
||||||
|
def lowercase_text(t):
    """Lowercase ``t`` while leaving every special token untouched.

    The special tokens are read from ``self.all_special_tokens`` of the
    enclosing tokenizer. Any occurrence of one of them, at any position
    in ``t``, is copied through verbatim; all other text is lowercased.

    Args:
        t: The text to lowercase.

    Returns:
        A copy of ``t`` in which only the non-special spans are lowercased.
    """
    # Drop empty entries: an empty alternative would match at every
    # position and prevent any text from being lowercased.
    special_toks = [s_tok for s_tok in self.all_special_tokens if s_tok]
    if not special_toks:
        # Nothing to protect, plain lowercasing is enough.
        return t.lower()

    # Regex alternation picks the first alternative that matches, so try
    # longer tokens first; otherwise a token that is a prefix of another
    # one would shadow it and the remainder would be lowercased.
    special_toks.sort(key=len, reverse=True)

    # No '^' anchor here: a special token must be preserved wherever it
    # appears, not only at the very start of the text.
    pattern = r'(' + r'|'.join(re.escape(s_tok) for s_tok in special_toks) + r')'

    # With a single capturing group, re.split() returns alternating
    # pieces: plain text at even indices, matched special tokens at odd
    # indices.
    pieces = re.split(pattern, t)
    return ''.join(
        piece if idx % 2 else piece.lower()
        for idx, piece in enumerate(pieces))
|
||||||
|
|
||||||
if self.init_kwargs.get('do_lower_case', False):
|
if self.init_kwargs.get('do_lower_case', False):
|
||||||
text = text.lower()
|
text = lowercase_text(text)
|
||||||
|
|
||||||
def split_on_token(tok, text):
|
def split_on_token(tok, text):
|
||||||
result = []
|
result = []
|
||||||
|
|||||||
Reference in New Issue
Block a user