Added option to setup pretrained tokenizer arguments
This commit is contained in:
@@ -63,6 +63,23 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||
'bert-base-cased-finetuned-mrpc': 512,
|
||||
}
|
||||
|
||||
PRETRAINED_INIT_CONFIGURATION = {
|
||||
'bert-base-uncased': {'do_lower_case': True},
|
||||
'bert-large-uncased': {'do_lower_case': True},
|
||||
'bert-base-cased': {'do_lower_case': False},
|
||||
'bert-large-cased': {'do_lower_case': False},
|
||||
'bert-base-multilingual-uncased': {'do_lower_case': True},
|
||||
'bert-base-multilingual-cased': {'do_lower_case': False},
|
||||
'bert-base-chinese': {'do_lower_case': False},
|
||||
'bert-base-german-cased': {'do_lower_case': False},
|
||||
'bert-large-uncased-whole-word-masking': {'do_lower_case': True},
|
||||
'bert-large-cased-whole-word-masking': {'do_lower_case': False},
|
||||
'bert-large-uncased-whole-word-masking-finetuned-squad': {'do_lower_case': True},
|
||||
'bert-large-cased-whole-word-masking-finetuned-squad': {'do_lower_case': False},
|
||||
'bert-base-cased-finetuned-mrpc': {'do_lower_case': False},
|
||||
}
|
||||
|
||||
|
||||
def load_vocab(vocab_file):
|
||||
"""Loads a vocabulary file into a dictionary."""
|
||||
vocab = collections.OrderedDict()
|
||||
@@ -100,6 +117,7 @@ class BertTokenizer(PreTrainedTokenizer):
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
|
||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
|
||||
def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None,
|
||||
@@ -199,24 +217,6 @@ class BertTokenizer(PreTrainedTokenizer):
|
||||
index += 1
|
||||
return (vocab_file,)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
|
||||
""" Instantiate a BertTokenizer from pre-trained vocabulary files.
|
||||
"""
|
||||
if pretrained_model_name_or_path in PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES:
|
||||
if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True):
|
||||
logger.warning("The pre-trained model you are loading is a cased model but you have not set "
|
||||
"`do_lower_case` to False. We are setting `do_lower_case=False` for you but "
|
||||
"you may want to check this behavior.")
|
||||
kwargs['do_lower_case'] = False
|
||||
elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True):
|
||||
logger.warning("The pre-trained model you are loading is an uncased model but you have set "
|
||||
"`do_lower_case` to False. We are setting `do_lower_case=True` for you "
|
||||
"but you may want to check this behavior.")
|
||||
kwargs['do_lower_case'] = True
|
||||
|
||||
return super(BertTokenizer, cls)._from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
|
||||
|
||||
class BasicTokenizer(object):
|
||||
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
|
||||
|
||||
Reference in New Issue
Block a user