From e14c6b52e37876ee642ffde49367c51b0d374f41 Mon Sep 17 00:00:00 2001 From: John Hewitt Date: Tue, 26 Feb 2019 20:11:24 -0800 Subject: [PATCH] add BertTokenizer flag to skip basic tokenization --- README.md | 5 +++-- pytorch_pretrained_bert/tokenization.py | 25 ++++++++++++++++++------- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 0e3ead57e9..c9b8549843 100644 --- a/README.md +++ b/README.md @@ -507,7 +507,7 @@ where Examples: ```python # BERT -tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True) +tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, do_basic_tokenize=True) model = BertForSequenceClassification.from_pretrained('bert-base-uncased') # OpenAI GPT @@ -803,11 +803,12 @@ This model *outputs*: `BertTokenizer` perform end-to-end tokenization, i.e. basic tokenization followed by WordPiece tokenization. -This class has four arguments: +This class has five arguments: - `vocab_file`: path to a vocabulary file. - `do_lower_case`: convert text to lower-case while tokenizing. **Default = True**. - `max_len`: max length to filter the input of the Transformer. Default to pre-trained value for the model if `None`. **Default = None** +- `do_basic_tokenize`: Do basic tokenization before WordPiece tokenization. Set to false if text is pre-tokenized. **Default = True**. - `never_split`: a list of tokens that should not be splitted during tokenization. 
**Default = `["[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"]`** and three methods: diff --git a/pytorch_pretrained_bert/tokenization.py b/pytorch_pretrained_bert/tokenization.py index 1fabea852a..9ee8be2039 100644 --- a/pytorch_pretrained_bert/tokenization.py +++ b/pytorch_pretrained_bert/tokenization.py @@ -74,8 +74,14 @@ def whitespace_tokenize(text): class BertTokenizer(object): """Runs end-to-end tokenization: punctuation splitting + wordpiece""" - def __init__(self, vocab_file, do_lower_case=True, max_len=None, + def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): + """Constructs a BertTokenizer. + + Args: + do_lower_case: Whether to lower case the input. + do_basic_tokenize: Whether to do basic tokenization before wordpiece. + """ if not os.path.isfile(vocab_file): raise ValueError( "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " @@ -83,16 +89,21 @@ class BertTokenizer(object): self.vocab = load_vocab(vocab_file) self.ids_to_tokens = collections.OrderedDict( [(ids, tok) for tok, ids in self.vocab.items()]) - self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, - never_split=never_split) + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, + never_split=never_split) self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) self.max_len = max_len if max_len is not None else int(1e12) def tokenize(self, text): - split_tokens = [] - for token in self.basic_tokenizer.tokenize(text): - for sub_token in self.wordpiece_tokenizer.tokenize(token): - split_tokens.append(sub_token) + if self.do_basic_tokenize: + split_tokens = [] + for token in self.basic_tokenizer.tokenize(text): + for sub_token in self.wordpiece_tokenizer.tokenize(token): + split_tokens.append(sub_token) + else: + split_tokens = 
self.wordpiece_tokenizer.tokenize(text) return split_tokens def convert_tokens_to_ids(self, tokens):