Tokenizer.from_pretrained: fetch all possible files remotely
This commit is contained in:
@@ -200,6 +200,8 @@ class PretrainedConfig(object):
|
||||
resume_download=resume_download,
|
||||
)
|
||||
# Load config dict
|
||||
if resolved_config_file is None:
|
||||
raise EnvironmentError
|
||||
config_dict = cls._dict_from_json_file(resolved_config_file)
|
||||
|
||||
except EnvironmentError:
|
||||
@@ -210,7 +212,7 @@ class PretrainedConfig(object):
|
||||
else:
|
||||
msg = (
|
||||
"Model name '{}' was not found in model name list. "
|
||||
"We assumed '{}' was a path or url to a configuration file named {} or "
|
||||
"We assumed '{}' was a path, a model identifier, or url to a configuration file named {} or "
|
||||
"a directory containing such a file but couldn't find any such file at this path or url.".format(
|
||||
pretrained_model_name_or_path, config_file, CONFIG_NAME,
|
||||
)
|
||||
|
||||
@@ -14,6 +14,7 @@ import tempfile
|
||||
from contextlib import contextmanager
|
||||
from functools import partial, wraps
|
||||
from hashlib import sha256
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import boto3
|
||||
@@ -122,7 +123,7 @@ def is_remote_url(url_or_filename):
|
||||
return parsed.scheme in ("http", "https", "s3")
|
||||
|
||||
|
||||
def hf_bucket_url(identifier, postfix=None, cdn=False):
|
||||
def hf_bucket_url(identifier, postfix=None, cdn=False) -> str:
|
||||
endpoint = CLOUDFRONT_DISTRIB_PREFIX if cdn else S3_BUCKET_PREFIX
|
||||
if postfix is None:
|
||||
return "/".join((endpoint, identifier))
|
||||
@@ -182,7 +183,7 @@ def filename_to_url(filename, cache_dir=None):
|
||||
|
||||
def cached_path(
|
||||
url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, user_agent=None
|
||||
):
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Given something that might be a URL (or might be a local path),
|
||||
determine which. If it's a URL, download the file and cache it, and
|
||||
@@ -193,6 +194,10 @@ def cached_path(
|
||||
force_download: if True, re-dowload the file even if it's already cached in the cache dir.
|
||||
resume_download: if True, resume the download if incompletly recieved file is found.
|
||||
user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
|
||||
|
||||
Return:
|
||||
None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
|
||||
Local path (string) otherwise
|
||||
"""
|
||||
if cache_dir is None:
|
||||
cache_dir = TRANSFORMERS_CACHE
|
||||
@@ -306,10 +311,14 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
|
||||
|
||||
def get_from_cache(
|
||||
url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None
|
||||
):
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Given a URL, look for the corresponding dataset in the local cache.
|
||||
Given a URL, look for the corresponding file in the local cache.
|
||||
If it's not there, download it. Then return the path to the cached file.
|
||||
|
||||
Return:
|
||||
None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
|
||||
Local path (string) otherwise
|
||||
"""
|
||||
if cache_dir is None:
|
||||
cache_dir = TRANSFORMERS_CACHE
|
||||
@@ -336,16 +345,25 @@ def get_from_cache(
|
||||
# get cache path to put the file
|
||||
cache_path = os.path.join(cache_dir, filename)
|
||||
|
||||
# If we don't have a connection (etag is None) and can't identify the file
|
||||
# etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
|
||||
# try to get the last downloaded one
|
||||
if not os.path.exists(cache_path) and etag is None:
|
||||
matching_files = [
|
||||
file
|
||||
for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
|
||||
if not file.endswith(".json") and not file.endswith(".lock")
|
||||
]
|
||||
if matching_files:
|
||||
cache_path = os.path.join(cache_dir, matching_files[-1])
|
||||
if etag is None:
|
||||
if os.path.exists(cache_path):
|
||||
return cache_path
|
||||
else:
|
||||
matching_files = [
|
||||
file
|
||||
for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
|
||||
if not file.endswith(".json") and not file.endswith(".lock")
|
||||
]
|
||||
if len(matching_files) > 0:
|
||||
return os.path.join(cache_dir, matching_files[-1])
|
||||
else:
|
||||
return None
|
||||
|
||||
# From now on, etag is not None.
|
||||
if os.path.exists(cache_path) and not force_download:
|
||||
return cache_path
|
||||
|
||||
# Prevent parallel downloads of the same file with a lock.
|
||||
lock_path = cache_path + ".lock"
|
||||
@@ -368,29 +386,26 @@ def get_from_cache(
|
||||
temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
|
||||
resume_size = 0
|
||||
|
||||
if etag is not None and (not os.path.exists(cache_path) or force_download):
|
||||
# Download to temporary file, then copy to cache dir once finished.
|
||||
# Otherwise you get corrupt cache entries if the download gets interrupted.
|
||||
with temp_file_manager() as temp_file:
|
||||
logger.info(
|
||||
"%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name
|
||||
)
|
||||
# Download to temporary file, then copy to cache dir once finished.
|
||||
# Otherwise you get corrupt cache entries if the download gets interrupted.
|
||||
with temp_file_manager() as temp_file:
|
||||
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
|
||||
|
||||
# GET file object
|
||||
if url.startswith("s3://"):
|
||||
if resume_download:
|
||||
logger.warn('Warning: resumable downloads are not implemented for "s3://" urls')
|
||||
s3_get(url, temp_file, proxies=proxies)
|
||||
else:
|
||||
http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
|
||||
# GET file object
|
||||
if url.startswith("s3://"):
|
||||
if resume_download:
|
||||
logger.warn('Warning: resumable downloads are not implemented for "s3://" urls')
|
||||
s3_get(url, temp_file, proxies=proxies)
|
||||
else:
|
||||
http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
|
||||
|
||||
logger.info("storing %s in cache at %s", url, cache_path)
|
||||
os.rename(temp_file.name, cache_path)
|
||||
logger.info("storing %s in cache at %s", url, cache_path)
|
||||
os.rename(temp_file.name, cache_path)
|
||||
|
||||
logger.info("creating metadata file for %s", cache_path)
|
||||
meta = {"url": url, "etag": etag}
|
||||
meta_path = cache_path + ".json"
|
||||
with open(meta_path, "w") as meta_file:
|
||||
json.dump(meta, meta_file)
|
||||
logger.info("creating metadata file for %s", cache_path)
|
||||
meta = {"url": url, "etag": etag}
|
||||
meta_path = cache_path + ".json"
|
||||
with open(meta_path, "w") as meta_file:
|
||||
json.dump(meta, meta_file)
|
||||
|
||||
return cache_path
|
||||
|
||||
@@ -264,7 +264,7 @@ class PreTrainedTokenizer(object):
|
||||
- a string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.: ``bert-base-uncased``.
|
||||
- a string with the `identifier name` of a predefined tokenizer that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
|
||||
- a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``.
|
||||
- (not applicable to all derived classes) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``.
|
||||
- (not applicable to all derived classes, deprecated) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``.
|
||||
|
||||
cache_dir: (`optional`) string:
|
||||
Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used.
|
||||
@@ -331,57 +331,42 @@ class PreTrainedTokenizer(object):
|
||||
# Get the vocabulary from local files
|
||||
logger.info(
|
||||
"Model name '{}' not found in model shortcut name list ({}). "
|
||||
"Assuming '{}' is a path or url to a directory containing tokenizer files.".format(
|
||||
"Assuming '{}' is a path, a model identifier, or url to a directory containing tokenizer files.".format(
|
||||
pretrained_model_name_or_path, ", ".join(s3_models), pretrained_model_name_or_path
|
||||
)
|
||||
)
|
||||
|
||||
# Look for the tokenizer main vocabulary files
|
||||
for file_id, file_name in cls.vocab_files_names.items():
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
# If a directory is provided we look for the standard filenames
|
||||
full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
|
||||
if not os.path.exists(full_file_name):
|
||||
logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
|
||||
full_file_name = None
|
||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
# If a path to a file is provided we use it (will only work for non-BPE tokenizer using a single vocabulary file)
|
||||
full_file_name = pretrained_model_name_or_path
|
||||
else:
|
||||
full_file_name = hf_bucket_url(pretrained_model_name_or_path, postfix=file_name)
|
||||
|
||||
vocab_files[file_id] = full_file_name
|
||||
|
||||
# Look for the additional tokens files
|
||||
additional_files_names = {
|
||||
"added_tokens_file": ADDED_TOKENS_FILE,
|
||||
"special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
|
||||
"tokenizer_config_file": TOKENIZER_CONFIG_FILE,
|
||||
}
|
||||
|
||||
# If a path to a file was provided, get the parent directory
|
||||
saved_directory = pretrained_model_name_or_path
|
||||
if os.path.exists(saved_directory) and not os.path.isdir(saved_directory):
|
||||
saved_directory = os.path.dirname(saved_directory)
|
||||
|
||||
for file_id, file_name in additional_files_names.items():
|
||||
full_file_name = os.path.join(saved_directory, file_name)
|
||||
if not os.path.exists(full_file_name):
|
||||
logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
|
||||
full_file_name = None
|
||||
vocab_files[file_id] = full_file_name
|
||||
|
||||
if all(full_file_name is None for full_file_name in vocab_files.values()):
|
||||
raise EnvironmentError(
|
||||
"Model name '{}' was not found in tokenizers model name list ({}). "
|
||||
"We assumed '{}' was a path or url to a directory containing vocabulary files "
|
||||
"named {} but couldn't find such vocabulary files at this path or url.".format(
|
||||
pretrained_model_name_or_path,
|
||||
", ".join(s3_models),
|
||||
pretrained_model_name_or_path,
|
||||
list(cls.vocab_files_names.values()),
|
||||
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
if len(cls.vocab_files_names) > 1:
|
||||
raise ValueError(
|
||||
"Calling {}.from_pretrained() with the path to a single file or url is not supported."
|
||||
"Use a model identifier or the path to a directory instead.".format(cls.__name__)
|
||||
)
|
||||
logger.warning(
|
||||
"Calling {}.from_pretrained() with the path to a single file or url is deprecated".format(
|
||||
cls.__name__
|
||||
)
|
||||
)
|
||||
file_id = list(cls.vocab_files_names.keys())[0]
|
||||
vocab_files[file_id] = pretrained_model_name_or_path
|
||||
else:
|
||||
# At this point pretrained_model_name_or_path is either a directory or a model identifier name
|
||||
additional_files_names = {
|
||||
"added_tokens_file": ADDED_TOKENS_FILE,
|
||||
"special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
|
||||
"tokenizer_config_file": TOKENIZER_CONFIG_FILE,
|
||||
}
|
||||
# Look for the tokenizer main vocabulary files + the additional tokens files
|
||||
for file_id, file_name in {**cls.vocab_files_names, **additional_files_names}.items():
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
|
||||
if not os.path.exists(full_file_name):
|
||||
logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
|
||||
full_file_name = None
|
||||
else:
|
||||
full_file_name = hf_bucket_url(pretrained_model_name_or_path, postfix=file_name)
|
||||
|
||||
vocab_files[file_id] = full_file_name
|
||||
|
||||
# Get files from url, cache, or disk depending on the case
|
||||
try:
|
||||
@@ -414,6 +399,18 @@ class PreTrainedTokenizer(object):
|
||||
|
||||
raise EnvironmentError(msg)
|
||||
|
||||
if all(full_file_name is None for full_file_name in resolved_vocab_files.values()):
|
||||
raise EnvironmentError(
|
||||
"Model name '{}' was not found in tokenizers model name list ({}). "
|
||||
"We assumed '{}' was a path, a model identifier, or url to a directory containing vocabulary files "
|
||||
"named {} but couldn't find such vocabulary files at this path or url.".format(
|
||||
pretrained_model_name_or_path,
|
||||
", ".join(s3_models),
|
||||
pretrained_model_name_or_path,
|
||||
list(cls.vocab_files_names.values()),
|
||||
)
|
||||
)
|
||||
|
||||
for file_id, file_path in vocab_files.items():
|
||||
if file_path == resolved_vocab_files[file_id]:
|
||||
logger.info("loading file {}".format(file_path))
|
||||
|
||||
@@ -56,3 +56,17 @@ class AutoTokenizerTest(unittest.TestCase):
|
||||
tokenizer = AutoTokenizer.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
|
||||
self.assertIsInstance(tokenizer, RobertaTokenizer)
|
||||
self.assertEqual(len(tokenizer), 20)
|
||||
|
||||
def test_tokenizer_identifier_with_correct_config(self):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
for tokenizer_class in [BertTokenizer, AutoTokenizer]:
|
||||
tokenizer = tokenizer_class.from_pretrained("wietsedv/bert-base-dutch-cased")
|
||||
self.assertIsInstance(tokenizer, BertTokenizer)
|
||||
self.assertEqual(tokenizer.basic_tokenizer.do_lower_case, False)
|
||||
self.assertEqual(tokenizer.max_len, 512)
|
||||
|
||||
def test_tokenizer_identifier_non_existent(self):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
for tokenizer_class in [BertTokenizer, AutoTokenizer]:
|
||||
with self.assertRaises(EnvironmentError):
|
||||
_ = tokenizer_class.from_pretrained("julien-c/herlolip-not-exists")
|
||||
|
||||
Reference in New Issue
Block a user