never split some text
This commit is contained in:
@@ -75,7 +75,8 @@ def whitespace_tokenize(text):
|
|||||||
class BertTokenizer(object):
|
class BertTokenizer(object):
|
||||||
"""Runs end-to-end tokenization: punctuation splitting + wordpiece"""
|
"""Runs end-to-end tokenization: punctuation splitting + wordpiece"""
|
||||||
|
|
||||||
def __init__(self, vocab_file, do_lower_case=True, max_len=None):
|
def __init__(self, vocab_file, do_lower_case=True, max_len=None,
|
||||||
|
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
|
||||||
if not os.path.isfile(vocab_file):
|
if not os.path.isfile(vocab_file):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
|
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
|
||||||
@@ -83,7 +84,8 @@ class BertTokenizer(object):
|
|||||||
self.vocab = load_vocab(vocab_file)
|
self.vocab = load_vocab(vocab_file)
|
||||||
self.ids_to_tokens = collections.OrderedDict(
|
self.ids_to_tokens = collections.OrderedDict(
|
||||||
[(ids, tok) for tok, ids in self.vocab.items()])
|
[(ids, tok) for tok, ids in self.vocab.items()])
|
||||||
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
|
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
|
||||||
|
never_split=never_split)
|
||||||
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
|
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
|
||||||
self.max_len = max_len if max_len is not None else int(1e12)
|
self.max_len = max_len if max_len is not None else int(1e12)
|
||||||
|
|
||||||
@@ -156,13 +158,16 @@ class BertTokenizer(object):
|
|||||||
class BasicTokenizer(object):
|
class BasicTokenizer(object):
|
||||||
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
|
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
|
||||||
|
|
||||||
def __init__(self, do_lower_case=True):
|
def __init__(self,
|
||||||
|
do_lower_case=True,
|
||||||
|
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
|
||||||
"""Constructs a BasicTokenizer.
|
"""Constructs a BasicTokenizer.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
do_lower_case: Whether to lower case the input.
|
do_lower_case: Whether to lower case the input.
|
||||||
"""
|
"""
|
||||||
self.do_lower_case = do_lower_case
|
self.do_lower_case = do_lower_case
|
||||||
|
self.never_split = never_split
|
||||||
|
|
||||||
def tokenize(self, text):
|
def tokenize(self, text):
|
||||||
"""Tokenizes a piece of text."""
|
"""Tokenizes a piece of text."""
|
||||||
@@ -198,6 +203,8 @@ class BasicTokenizer(object):
|
|||||||
|
|
||||||
def _run_split_on_punc(self, text):
|
def _run_split_on_punc(self, text):
|
||||||
"""Splits punctuation on a piece of text."""
|
"""Splits punctuation on a piece of text."""
|
||||||
|
if text in self.never_split:
|
||||||
|
return [text]
|
||||||
chars = list(text)
|
chars = list(text)
|
||||||
i = 0
|
i = 0
|
||||||
start_new_word = True
|
start_new_word = True
|
||||||
|
|||||||
Reference in New Issue
Block a user