Tokenizer.from_pretrained: fetch all possible files remotely
This commit is contained in:
@@ -200,6 +200,8 @@ class PretrainedConfig(object):
|
|||||||
resume_download=resume_download,
|
resume_download=resume_download,
|
||||||
)
|
)
|
||||||
# Load config dict
|
# Load config dict
|
||||||
|
if resolved_config_file is None:
|
||||||
|
raise EnvironmentError
|
||||||
config_dict = cls._dict_from_json_file(resolved_config_file)
|
config_dict = cls._dict_from_json_file(resolved_config_file)
|
||||||
|
|
||||||
except EnvironmentError:
|
except EnvironmentError:
|
||||||
@@ -210,7 +212,7 @@ class PretrainedConfig(object):
|
|||||||
else:
|
else:
|
||||||
msg = (
|
msg = (
|
||||||
"Model name '{}' was not found in model name list. "
|
"Model name '{}' was not found in model name list. "
|
||||||
"We assumed '{}' was a path or url to a configuration file named {} or "
|
"We assumed '{}' was a path, a model identifier, or url to a configuration file named {} or "
|
||||||
"a directory containing such a file but couldn't find any such file at this path or url.".format(
|
"a directory containing such a file but couldn't find any such file at this path or url.".format(
|
||||||
pretrained_model_name_or_path, config_file, CONFIG_NAME,
|
pretrained_model_name_or_path, config_file, CONFIG_NAME,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import tempfile
|
|||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
|
from typing import Optional
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
@@ -122,7 +123,7 @@ def is_remote_url(url_or_filename):
|
|||||||
return parsed.scheme in ("http", "https", "s3")
|
return parsed.scheme in ("http", "https", "s3")
|
||||||
|
|
||||||
|
|
||||||
def hf_bucket_url(identifier, postfix=None, cdn=False):
|
def hf_bucket_url(identifier, postfix=None, cdn=False) -> str:
|
||||||
endpoint = CLOUDFRONT_DISTRIB_PREFIX if cdn else S3_BUCKET_PREFIX
|
endpoint = CLOUDFRONT_DISTRIB_PREFIX if cdn else S3_BUCKET_PREFIX
|
||||||
if postfix is None:
|
if postfix is None:
|
||||||
return "/".join((endpoint, identifier))
|
return "/".join((endpoint, identifier))
|
||||||
@@ -182,7 +183,7 @@ def filename_to_url(filename, cache_dir=None):
|
|||||||
|
|
||||||
def cached_path(
|
def cached_path(
|
||||||
url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, user_agent=None
|
url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, user_agent=None
|
||||||
):
|
) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Given something that might be a URL (or might be a local path),
|
Given something that might be a URL (or might be a local path),
|
||||||
determine which. If it's a URL, download the file and cache it, and
|
determine which. If it's a URL, download the file and cache it, and
|
||||||
@@ -193,6 +194,10 @@ def cached_path(
|
|||||||
force_download: if True, re-download the file even if it's already cached in the cache dir.
|
force_download: if True, re-download the file even if it's already cached in the cache dir.
|
||||||
resume_download: if True, resume the download if an incompletely received file is found.
|
resume_download: if True, resume the download if an incompletely received file is found.
|
||||||
user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
|
user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
|
||||||
|
Local path (string) otherwise
|
||||||
"""
|
"""
|
||||||
if cache_dir is None:
|
if cache_dir is None:
|
||||||
cache_dir = TRANSFORMERS_CACHE
|
cache_dir = TRANSFORMERS_CACHE
|
||||||
@@ -306,10 +311,14 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
|
|||||||
|
|
||||||
def get_from_cache(
|
def get_from_cache(
|
||||||
url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None
|
url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None
|
||||||
):
|
) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Given a URL, look for the corresponding dataset in the local cache.
|
Given a URL, look for the corresponding file in the local cache.
|
||||||
If it's not there, download it. Then return the path to the cached file.
|
If it's not there, download it. Then return the path to the cached file.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
|
||||||
|
Local path (string) otherwise
|
||||||
"""
|
"""
|
||||||
if cache_dir is None:
|
if cache_dir is None:
|
||||||
cache_dir = TRANSFORMERS_CACHE
|
cache_dir = TRANSFORMERS_CACHE
|
||||||
@@ -336,16 +345,25 @@ def get_from_cache(
|
|||||||
# get cache path to put the file
|
# get cache path to put the file
|
||||||
cache_path = os.path.join(cache_dir, filename)
|
cache_path = os.path.join(cache_dir, filename)
|
||||||
|
|
||||||
# If we don't have a connection (etag is None) and can't identify the file
|
# etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
|
||||||
# try to get the last downloaded one
|
# try to get the last downloaded one
|
||||||
if not os.path.exists(cache_path) and etag is None:
|
if etag is None:
|
||||||
|
if os.path.exists(cache_path):
|
||||||
|
return cache_path
|
||||||
|
else:
|
||||||
matching_files = [
|
matching_files = [
|
||||||
file
|
file
|
||||||
for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
|
for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
|
||||||
if not file.endswith(".json") and not file.endswith(".lock")
|
if not file.endswith(".json") and not file.endswith(".lock")
|
||||||
]
|
]
|
||||||
if matching_files:
|
if len(matching_files) > 0:
|
||||||
cache_path = os.path.join(cache_dir, matching_files[-1])
|
return os.path.join(cache_dir, matching_files[-1])
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# From now on, etag is not None.
|
||||||
|
if os.path.exists(cache_path) and not force_download:
|
||||||
|
return cache_path
|
||||||
|
|
||||||
# Prevent parallel downloads of the same file with a lock.
|
# Prevent parallel downloads of the same file with a lock.
|
||||||
lock_path = cache_path + ".lock"
|
lock_path = cache_path + ".lock"
|
||||||
@@ -368,13 +386,10 @@ def get_from_cache(
|
|||||||
temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
|
temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
|
||||||
resume_size = 0
|
resume_size = 0
|
||||||
|
|
||||||
if etag is not None and (not os.path.exists(cache_path) or force_download):
|
|
||||||
# Download to temporary file, then copy to cache dir once finished.
|
# Download to temporary file, then copy to cache dir once finished.
|
||||||
# Otherwise you get corrupt cache entries if the download gets interrupted.
|
# Otherwise you get corrupt cache entries if the download gets interrupted.
|
||||||
with temp_file_manager() as temp_file:
|
with temp_file_manager() as temp_file:
|
||||||
logger.info(
|
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
|
||||||
"%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name
|
|
||||||
)
|
|
||||||
|
|
||||||
# GET file object
|
# GET file object
|
||||||
if url.startswith("s3://"):
|
if url.startswith("s3://"):
|
||||||
|
|||||||
@@ -264,7 +264,7 @@ class PreTrainedTokenizer(object):
|
|||||||
- a string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.: ``bert-base-uncased``.
|
- a string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.: ``bert-base-uncased``.
|
||||||
- a string with the `identifier name` of a predefined tokenizer that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
|
- a string with the `identifier name` of a predefined tokenizer that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
|
||||||
- a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``.
|
- a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``.
|
||||||
- (not applicable to all derived classes) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``.
|
- (not applicable to all derived classes, deprecated) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``.
|
||||||
|
|
||||||
cache_dir: (`optional`) string:
|
cache_dir: (`optional`) string:
|
||||||
Path to a directory in which downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used.
|
Path to a directory in which downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used.
|
||||||
@@ -331,57 +331,42 @@ class PreTrainedTokenizer(object):
|
|||||||
# Get the vocabulary from local files
|
# Get the vocabulary from local files
|
||||||
logger.info(
|
logger.info(
|
||||||
"Model name '{}' not found in model shortcut name list ({}). "
|
"Model name '{}' not found in model shortcut name list ({}). "
|
||||||
"Assuming '{}' is a path or url to a directory containing tokenizer files.".format(
|
"Assuming '{}' is a path, a model identifier, or url to a directory containing tokenizer files.".format(
|
||||||
pretrained_model_name_or_path, ", ".join(s3_models), pretrained_model_name_or_path
|
pretrained_model_name_or_path, ", ".join(s3_models), pretrained_model_name_or_path
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Look for the tokenizer main vocabulary files
|
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||||
for file_id, file_name in cls.vocab_files_names.items():
|
if len(cls.vocab_files_names) > 1:
|
||||||
if os.path.isdir(pretrained_model_name_or_path):
|
raise ValueError(
|
||||||
# If a directory is provided we look for the standard filenames
|
"Calling {}.from_pretrained() with the path to a single file or url is not supported."
|
||||||
full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
|
"Use a model identifier or the path to a directory instead.".format(cls.__name__)
|
||||||
if not os.path.exists(full_file_name):
|
)
|
||||||
logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
|
logger.warning(
|
||||||
full_file_name = None
|
"Calling {}.from_pretrained() with the path to a single file or url is deprecated".format(
|
||||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
cls.__name__
|
||||||
# If a path to a file is provided we use it (will only work for non-BPE tokenizer using a single vocabulary file)
|
)
|
||||||
full_file_name = pretrained_model_name_or_path
|
)
|
||||||
|
file_id = list(cls.vocab_files_names.keys())[0]
|
||||||
|
vocab_files[file_id] = pretrained_model_name_or_path
|
||||||
else:
|
else:
|
||||||
full_file_name = hf_bucket_url(pretrained_model_name_or_path, postfix=file_name)
|
# At this point pretrained_model_name_or_path is either a directory or a model identifier name
|
||||||
|
|
||||||
vocab_files[file_id] = full_file_name
|
|
||||||
|
|
||||||
# Look for the additional tokens files
|
|
||||||
additional_files_names = {
|
additional_files_names = {
|
||||||
"added_tokens_file": ADDED_TOKENS_FILE,
|
"added_tokens_file": ADDED_TOKENS_FILE,
|
||||||
"special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
|
"special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
|
||||||
"tokenizer_config_file": TOKENIZER_CONFIG_FILE,
|
"tokenizer_config_file": TOKENIZER_CONFIG_FILE,
|
||||||
}
|
}
|
||||||
|
# Look for the tokenizer main vocabulary files + the additional tokens files
|
||||||
# If a path to a file was provided, get the parent directory
|
for file_id, file_name in {**cls.vocab_files_names, **additional_files_names}.items():
|
||||||
saved_directory = pretrained_model_name_or_path
|
if os.path.isdir(pretrained_model_name_or_path):
|
||||||
if os.path.exists(saved_directory) and not os.path.isdir(saved_directory):
|
full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
|
||||||
saved_directory = os.path.dirname(saved_directory)
|
|
||||||
|
|
||||||
for file_id, file_name in additional_files_names.items():
|
|
||||||
full_file_name = os.path.join(saved_directory, file_name)
|
|
||||||
if not os.path.exists(full_file_name):
|
if not os.path.exists(full_file_name):
|
||||||
logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
|
logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
|
||||||
full_file_name = None
|
full_file_name = None
|
||||||
vocab_files[file_id] = full_file_name
|
else:
|
||||||
|
full_file_name = hf_bucket_url(pretrained_model_name_or_path, postfix=file_name)
|
||||||
|
|
||||||
if all(full_file_name is None for full_file_name in vocab_files.values()):
|
vocab_files[file_id] = full_file_name
|
||||||
raise EnvironmentError(
|
|
||||||
"Model name '{}' was not found in tokenizers model name list ({}). "
|
|
||||||
"We assumed '{}' was a path or url to a directory containing vocabulary files "
|
|
||||||
"named {} but couldn't find such vocabulary files at this path or url.".format(
|
|
||||||
pretrained_model_name_or_path,
|
|
||||||
", ".join(s3_models),
|
|
||||||
pretrained_model_name_or_path,
|
|
||||||
list(cls.vocab_files_names.values()),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get files from url, cache, or disk depending on the case
|
# Get files from url, cache, or disk depending on the case
|
||||||
try:
|
try:
|
||||||
@@ -414,6 +399,18 @@ class PreTrainedTokenizer(object):
|
|||||||
|
|
||||||
raise EnvironmentError(msg)
|
raise EnvironmentError(msg)
|
||||||
|
|
||||||
|
if all(full_file_name is None for full_file_name in resolved_vocab_files.values()):
|
||||||
|
raise EnvironmentError(
|
||||||
|
"Model name '{}' was not found in tokenizers model name list ({}). "
|
||||||
|
"We assumed '{}' was a path, a model identifier, or url to a directory containing vocabulary files "
|
||||||
|
"named {} but couldn't find such vocabulary files at this path or url.".format(
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
", ".join(s3_models),
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
list(cls.vocab_files_names.values()),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
for file_id, file_path in vocab_files.items():
|
for file_id, file_path in vocab_files.items():
|
||||||
if file_path == resolved_vocab_files[file_id]:
|
if file_path == resolved_vocab_files[file_id]:
|
||||||
logger.info("loading file {}".format(file_path))
|
logger.info("loading file {}".format(file_path))
|
||||||
|
|||||||
@@ -56,3 +56,17 @@ class AutoTokenizerTest(unittest.TestCase):
|
|||||||
tokenizer = AutoTokenizer.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
|
tokenizer = AutoTokenizer.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
|
||||||
self.assertIsInstance(tokenizer, RobertaTokenizer)
|
self.assertIsInstance(tokenizer, RobertaTokenizer)
|
||||||
self.assertEqual(len(tokenizer), 20)
|
self.assertEqual(len(tokenizer), 20)
|
||||||
|
|
||||||
|
def test_tokenizer_identifier_with_correct_config(self):
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
for tokenizer_class in [BertTokenizer, AutoTokenizer]:
|
||||||
|
tokenizer = tokenizer_class.from_pretrained("wietsedv/bert-base-dutch-cased")
|
||||||
|
self.assertIsInstance(tokenizer, BertTokenizer)
|
||||||
|
self.assertEqual(tokenizer.basic_tokenizer.do_lower_case, False)
|
||||||
|
self.assertEqual(tokenizer.max_len, 512)
|
||||||
|
|
||||||
|
def test_tokenizer_identifier_non_existent(self):
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
for tokenizer_class in [BertTokenizer, AutoTokenizer]:
|
||||||
|
with self.assertRaises(EnvironmentError):
|
||||||
|
_ = tokenizer_class.from_pretrained("julien-c/herlolip-not-exists")
|
||||||
|
|||||||
Reference in New Issue
Block a user