fixed loading pre-trained tokenizer from directory
This commit is contained in:
@@ -478,7 +478,7 @@ class PreTrainedBertModel(nn.Module):
|
|||||||
"associated to this path or url.".format(
|
"associated to this path or url.".format(
|
||||||
pretrained_model_name,
|
pretrained_model_name,
|
||||||
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
|
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
|
||||||
pretrained_model_name))
|
archive_file))
|
||||||
return None
|
return None
|
||||||
if resolved_archive_file == archive_file:
|
if resolved_archive_file == archive_file:
|
||||||
logger.info("loading archive file {}".format(archive_file))
|
logger.info("loading archive file {}".format(archive_file))
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ PRETRAINED_VOCAB_ARCHIVE_MAP = {
|
|||||||
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
|
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
|
||||||
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
|
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
|
||||||
}
|
}
|
||||||
|
VOCAB_NAME = 'vocab.txt'
|
||||||
|
|
||||||
|
|
||||||
def load_vocab(vocab_file):
|
def load_vocab(vocab_file):
|
||||||
@@ -100,7 +101,7 @@ class BertTokenizer(object):
|
|||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name, do_lower_case=True):
|
def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs):
|
||||||
"""
|
"""
|
||||||
Instantiate a PreTrainedBertModel from a pre-trained model file.
|
Instantiate a PreTrainedBertModel from a pre-trained model file.
|
||||||
Download and cache the pre-trained model file if needed.
|
Download and cache the pre-trained model file if needed.
|
||||||
@@ -109,16 +110,11 @@ class BertTokenizer(object):
|
|||||||
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name]
|
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name]
|
||||||
else:
|
else:
|
||||||
vocab_file = pretrained_model_name
|
vocab_file = pretrained_model_name
|
||||||
|
if os.path.isdir(vocab_file):
|
||||||
|
vocab_file = os.path.join(vocab_file, VOCAB_NAME)
|
||||||
# redirect to the cache, if necessary
|
# redirect to the cache, if necessary
|
||||||
try:
|
try:
|
||||||
resolved_vocab_file = cached_path(vocab_file)
|
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
|
||||||
if resolved_vocab_file == vocab_file:
|
|
||||||
logger.info("loading vocabulary file {}".format(vocab_file))
|
|
||||||
else:
|
|
||||||
logger.info("loading vocabulary file {} from cache at {}".format(
|
|
||||||
vocab_file, resolved_vocab_file))
|
|
||||||
# Instantiate tokenizer.
|
|
||||||
tokenizer = cls(resolved_vocab_file, do_lower_case)
|
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Model name '{}' was not found in model name list ({}). "
|
"Model name '{}' was not found in model name list ({}). "
|
||||||
@@ -126,8 +122,15 @@ class BertTokenizer(object):
|
|||||||
"associated to this path or url.".format(
|
"associated to this path or url.".format(
|
||||||
pretrained_model_name,
|
pretrained_model_name,
|
||||||
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
|
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
|
||||||
pretrained_model_name))
|
vocab_file))
|
||||||
tokenizer = None
|
return None
|
||||||
|
if resolved_vocab_file == vocab_file:
|
||||||
|
logger.info("loading vocabulary file {}".format(vocab_file))
|
||||||
|
else:
|
||||||
|
logger.info("loading vocabulary file {} from cache at {}".format(
|
||||||
|
vocab_file, resolved_vocab_file))
|
||||||
|
# Instantiate tokenizer.
|
||||||
|
tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
|
||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user