Fix bug which lowercases special tokens

This commit is contained in:
Michael Watkins
2019-12-04 17:53:25 +02:00
committed by Lysandre Debut
parent 35401fe50f
commit 2670b0d682
2 changed files with 18 additions and 5 deletions

View File

@@ -22,6 +22,7 @@ import json
import six
import copy
import itertools
import re
from io import open
from .file_utils import cached_path, is_tf_available, is_torch_available
@@ -520,7 +521,7 @@ class PreTrainedTokenizer(object):
to_add_tokens = []
for token in new_tokens:
assert isinstance(token, str) or (six.PY2 and isinstance(token, unicode))
if self.init_kwargs.get('do_lower_case', False):
if self.init_kwargs.get('do_lower_case', False) and token not in self.all_special_tokens:
token = token.lower()
if token != self.unk_token and \
self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token) and \
@@ -615,8 +616,18 @@ class PreTrainedTokenizer(object):
Take care of added tokens.
"""
def lowercase_text(t):
# convert non-special tokens to lowercase
escaped_special_toks = [re.escape(s_tok) for s_tok in self.all_special_tokens]
pattern = r'(^' + r'|'.join(escaped_special_toks) + r')|' + \
r'(.+?)'
return re.sub(
pattern,
lambda m: m.groups()[0] or m.groups()[1].lower(),
t)
if self.init_kwargs.get('do_lower_case', False):
text = text.lower()
text = lowercase_text(text)
def split_on_token(tok, text):
result = []