add BertTokenizer flag to skip basic tokenization
This commit is contained in:
@@ -507,7 +507,7 @@ where
|
|||||||
Examples:
|
Examples:
|
||||||
```python
|
```python
|
||||||
# BERT
|
# BERT
|
||||||
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
|
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, do_basic_tokenize=True)
|
||||||
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
|
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
|
||||||
|
|
||||||
# OpenAI GPT
|
# OpenAI GPT
|
||||||
@@ -803,11 +803,12 @@ This model *outputs*:
|
|||||||
|
|
||||||
`BertTokenizer` perform end-to-end tokenization, i.e. basic tokenization followed by WordPiece tokenization.
|
`BertTokenizer` perform end-to-end tokenization, i.e. basic tokenization followed by WordPiece tokenization.
|
||||||
|
|
||||||
This class has four arguments:
|
This class has five arguments:
|
||||||
|
|
||||||
- `vocab_file`: path to a vocabulary file.
|
- `vocab_file`: path to a vocabulary file.
|
||||||
- `do_lower_case`: convert text to lower-case while tokenizing. **Default = True**.
|
- `do_lower_case`: convert text to lower-case while tokenizing. **Default = True**.
|
||||||
- `max_len`: max length to filter the input of the Transformer. Default to pre-trained value for the model if `None`. **Default = None**
|
- `max_len`: max length to filter the input of the Transformer. Default to pre-trained value for the model if `None`. **Default = None**
|
||||||
|
- `do_basic_tokenize`: Do basic tokenization before wordpice tokenization. Set to false if text is pre-tokenized. **Default = True**.
|
||||||
- `never_split`: a list of tokens that should not be splitted during tokenization. **Default = `["[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"]`**
|
- `never_split`: a list of tokens that should not be splitted during tokenization. **Default = `["[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"]`**
|
||||||
|
|
||||||
and three methods:
|
and three methods:
|
||||||
|
|||||||
@@ -74,8 +74,14 @@ def whitespace_tokenize(text):
|
|||||||
class BertTokenizer(object):
|
class BertTokenizer(object):
|
||||||
"""Runs end-to-end tokenization: punctuation splitting + wordpiece"""
|
"""Runs end-to-end tokenization: punctuation splitting + wordpiece"""
|
||||||
|
|
||||||
def __init__(self, vocab_file, do_lower_case=True, max_len=None,
|
def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True,
|
||||||
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
|
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
|
||||||
|
"""Constructs a BertTokenizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
do_lower_case: Whether to lower case the input.
|
||||||
|
do_wordpiece_only: Whether to do basic tokenization before wordpiece.
|
||||||
|
"""
|
||||||
if not os.path.isfile(vocab_file):
|
if not os.path.isfile(vocab_file):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
|
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
|
||||||
@@ -83,16 +89,21 @@ class BertTokenizer(object):
|
|||||||
self.vocab = load_vocab(vocab_file)
|
self.vocab = load_vocab(vocab_file)
|
||||||
self.ids_to_tokens = collections.OrderedDict(
|
self.ids_to_tokens = collections.OrderedDict(
|
||||||
[(ids, tok) for tok, ids in self.vocab.items()])
|
[(ids, tok) for tok, ids in self.vocab.items()])
|
||||||
|
self.do_basic_tokenize = do_basic_tokenize
|
||||||
|
if do_basic_tokenize:
|
||||||
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
|
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
|
||||||
never_split=never_split)
|
never_split=never_split)
|
||||||
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
|
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
|
||||||
self.max_len = max_len if max_len is not None else int(1e12)
|
self.max_len = max_len if max_len is not None else int(1e12)
|
||||||
|
|
||||||
def tokenize(self, text):
|
def tokenize(self, text):
|
||||||
|
if self.do_basic_tokenize:
|
||||||
split_tokens = []
|
split_tokens = []
|
||||||
for token in self.basic_tokenizer.tokenize(text):
|
for token in self.basic_tokenizer.tokenize(text):
|
||||||
for sub_token in self.wordpiece_tokenizer.tokenize(token):
|
for sub_token in self.wordpiece_tokenizer.tokenize(token):
|
||||||
split_tokens.append(sub_token)
|
split_tokens.append(sub_token)
|
||||||
|
else:
|
||||||
|
split_tokens = self.wordpiece_tokenizer.tokenize(text)
|
||||||
return split_tokens
|
return split_tokens
|
||||||
|
|
||||||
def convert_tokens_to_ids(self, tokens):
|
def convert_tokens_to_ids(self, tokens):
|
||||||
|
|||||||
Reference in New Issue
Block a user