Fix bug which lowercases special tokens
This commit is contained in:
committed by
Lysandre Debut
parent
35401fe50f
commit
2670b0d682
@@ -22,6 +22,7 @@ import json
|
||||
import six
|
||||
import copy
|
||||
import itertools
|
||||
import re
|
||||
from io import open
|
||||
|
||||
from .file_utils import cached_path, is_tf_available, is_torch_available
|
||||
@@ -520,7 +521,7 @@ class PreTrainedTokenizer(object):
|
||||
to_add_tokens = []
|
||||
for token in new_tokens:
|
||||
assert isinstance(token, str) or (six.PY2 and isinstance(token, unicode))
|
||||
if self.init_kwargs.get('do_lower_case', False):
|
||||
if self.init_kwargs.get('do_lower_case', False) and token not in self.all_special_tokens:
|
||||
token = token.lower()
|
||||
if token != self.unk_token and \
|
||||
self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token) and \
|
||||
@@ -615,8 +616,18 @@ class PreTrainedTokenizer(object):
|
||||
|
||||
Take care of added tokens.
|
||||
"""
|
||||
def lowercase_text(t):
    """Return ``t`` lowercased, leaving special tokens untouched.

    Every occurrence of a token listed in ``self.all_special_tokens``
    (read from the enclosing scope) is copied through verbatim, wherever
    it appears in ``t``; every other character is lowercased.
    """
    # Escape the tokens so characters such as '[' or '|' are matched
    # literally. Empty tokens are dropped: an empty alternative would match
    # the empty string at every position.
    escaped_special_toks = [
        re.escape(s_tok) for s_tok in self.all_special_tokens if s_tok
    ]
    if not escaped_special_toks:
        # Nothing to protect: plain lowercasing is all that is needed.
        return t.lower()
    # Group 1 matches any special token at ANY position. There is
    # deliberately no '^' here: alternation binds looser than an anchor, so
    # '(^a|b)' would anchor only the first token to the start of the text.
    # Group 2 lazily consumes one ordinary character at a time.
    pattern = r'(' + r'|'.join(escaped_special_toks) + r')|' + \
              r'(.+?)'
    return re.sub(
        pattern,
        # Keep a matched special token as-is; lowercase anything else.
        lambda m: m.group(1) or m.group(2).lower(),
        t)
|
||||
|
||||
if self.init_kwargs.get('do_lower_case', False):
|
||||
text = text.lower()
|
||||
text = lowercase_text(text)
|
||||
|
||||
def split_on_token(tok, text):
|
||||
result = []
|
||||
|
||||
Reference in New Issue
Block a user