diff --git a/pytorch_pretrained_bert/tokenization.py b/pytorch_pretrained_bert/tokenization.py index 9cfb3d8ce9..9cb36a1b46 100644 --- a/pytorch_pretrained_bert/tokenization.py +++ b/pytorch_pretrained_bert/tokenization.py @@ -75,7 +75,8 @@ def whitespace_tokenize(text): class BertTokenizer(object): """Runs end-to-end tokenization: punctuation splitting + wordpiece""" - def __init__(self, vocab_file, do_lower_case=True, max_len=None): + def __init__(self, vocab_file, do_lower_case=True, max_len=None, + never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): if not os.path.isfile(vocab_file): raise ValueError( "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " @@ -83,7 +84,8 @@ class BertTokenizer(object): self.vocab = load_vocab(vocab_file) self.ids_to_tokens = collections.OrderedDict( [(ids, tok) for tok, ids in self.vocab.items()]) - self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) + self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, + never_split=never_split) self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) self.max_len = max_len if max_len is not None else int(1e12) @@ -156,13 +158,16 @@ class BertTokenizer(object): class BasicTokenizer(object): """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" - def __init__(self, do_lower_case=True): + def __init__(self, + do_lower_case=True, + never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): """Constructs a BasicTokenizer. Args: do_lower_case: Whether to lower case the input. """ self.do_lower_case = do_lower_case + self.never_split = never_split def tokenize(self, text): """Tokenizes a piece of text.""" @@ -198,6 +203,8 @@ class BasicTokenizer(object): def _run_split_on_punc(self, text): """Splits punctuation on a piece of text.""" + if text in self.never_split: + return [text] chars = list(text) i = 0 start_new_word = True