From 23a2cea8cb95864ddb7e7e80e126e4f083640882 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Wed, 15 Jan 2020 21:07:26 +0000 Subject: [PATCH] Tokenizer.from_pretrained: fetch all possible files remotely --- src/transformers/configuration_utils.py | 4 +- src/transformers/file_utils.py | 83 +++++++++++++---------- src/transformers/tokenization_utils.py | 89 ++++++++++++------------- tests/test_tokenization_auto.py | 14 ++++ 4 files changed, 109 insertions(+), 81 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index f980e5e14c..62fed9ef04 100644 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -200,6 +200,8 @@ class PretrainedConfig(object): resume_download=resume_download, ) # Load config dict + if resolved_config_file is None: + raise EnvironmentError config_dict = cls._dict_from_json_file(resolved_config_file) except EnvironmentError: @@ -210,7 +212,7 @@ class PretrainedConfig(object): else: msg = ( "Model name '{}' was not found in model name list. " - "We assumed '{}' was a path or url to a configuration file named {} or " + "We assumed '{}' was a path, a model identifier, or url to a configuration file named {} or " "a directory containing such a file but couldn't find any such file at this path or url.".format( pretrained_model_name_or_path, config_file, CONFIG_NAME, ) diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 259c1c5643..3b90dca7c2 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -14,6 +14,7 @@ import tempfile from contextlib import contextmanager from functools import partial, wraps from hashlib import sha256 +from typing import Optional from urllib.parse import urlparse import boto3 @@ -122,7 +123,7 @@ def is_remote_url(url_or_filename): return parsed.scheme in ("http", "https", "s3") -def hf_bucket_url(identifier, postfix=None, cdn=False): +def hf_bucket_url(identifier, postfix=None, cdn=False) -> str: endpoint = CLOUDFRONT_DISTRIB_PREFIX if cdn else S3_BUCKET_PREFIX if postfix is None: return "/".join((endpoint, identifier)) @@ -182,7 +183,7 @@ def filename_to_url(filename, cache_dir=None): def cached_path( url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, user_agent=None -): +) -> Optional[str]: """ Given something that might be a URL (or might be a local path), determine which. If it's a URL, download the file and cache it, and @@ -193,6 +194,10 @@ def cached_path( force_download: if True, re-dowload the file even if it's already cached in the cache dir. resume_download: if True, resume the download if incompletly recieved file is found. user_agent: Optional string or dict that will be appended to the user-agent on remote requests. + + Return: + None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk). + Local path (string) otherwise """ if cache_dir is None: cache_dir = TRANSFORMERS_CACHE @@ -306,10 +311,14 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None): def get_from_cache( url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None -): +) -> Optional[str]: """ - Given a URL, look for the corresponding dataset in the local cache. + Given a URL, look for the corresponding file in the local cache. If it's not there, download it. Then return the path to the cached file. + + Return: + None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk). + Local path (string) otherwise """ if cache_dir is None: cache_dir = TRANSFORMERS_CACHE @@ -336,16 +345,25 @@ def get_from_cache( # get cache path to put the file cache_path = os.path.join(cache_dir, filename) - # If we don't have a connection (etag is None) and can't identify the file + # etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible. # try to get the last downloaded one - if not os.path.exists(cache_path) and etag is None: - matching_files = [ - file - for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*") - if not file.endswith(".json") and not file.endswith(".lock") - ] - if matching_files: - cache_path = os.path.join(cache_dir, matching_files[-1]) + if etag is None: + if os.path.exists(cache_path): + return cache_path + else: + matching_files = [ + file + for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*") + if not file.endswith(".json") and not file.endswith(".lock") + ] + if len(matching_files) > 0: + return os.path.join(cache_dir, matching_files[-1]) + else: + return None + + # From now on, etag is not None. + if os.path.exists(cache_path) and not force_download: + return cache_path # Prevent parallel downloads of the same file with a lock. lock_path = cache_path + ".lock" @@ -368,29 +386,26 @@ def get_from_cache( temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False) resume_size = 0 - if etag is not None and (not os.path.exists(cache_path) or force_download): - # Download to temporary file, then copy to cache dir once finished. - # Otherwise you get corrupt cache entries if the download gets interrupted. - with temp_file_manager() as temp_file: - logger.info( - "%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name - ) + # Download to temporary file, then copy to cache dir once finished. + # Otherwise you get corrupt cache entries if the download gets interrupted. + with temp_file_manager() as temp_file: + logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name) - # GET file object - if url.startswith("s3://"): - if resume_download: - logger.warn('Warning: resumable downloads are not implemented for "s3://" urls') - s3_get(url, temp_file, proxies=proxies) - else: - http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent) + # GET file object + if url.startswith("s3://"): + if resume_download: + logger.warn('Warning: resumable downloads are not implemented for "s3://" urls') + s3_get(url, temp_file, proxies=proxies) + else: + http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent) - logger.info("storing %s in cache at %s", url, cache_path) - os.rename(temp_file.name, cache_path) + logger.info("storing %s in cache at %s", url, cache_path) + os.rename(temp_file.name, cache_path) - logger.info("creating metadata file for %s", cache_path) - meta = {"url": url, "etag": etag} - meta_path = cache_path + ".json" - with open(meta_path, "w") as meta_file: - json.dump(meta, meta_file) + logger.info("creating metadata file for %s", cache_path) + meta = {"url": url, "etag": etag} + meta_path = cache_path + ".json" + with open(meta_path, "w") as meta_file: + json.dump(meta, meta_file) return cache_path diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py index 7a42eba46e..b4545a2d44 100644 --- a/src/transformers/tokenization_utils.py +++ b/src/transformers/tokenization_utils.py @@ -264,7 +264,7 @@ class PreTrainedTokenizer(object): - a string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.: ``bert-base-uncased``. - a string with the `identifier name` of a predefined tokenizer that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``. - a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``. - - (not applicable to all derived classes) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``. + - (not applicable to all derived classes, deprecated) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``. cache_dir: (`optional`) string: Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used. @@ -331,57 +331,42 @@ class PreTrainedTokenizer(object): # Get the vocabulary from local files logger.info( "Model name '{}' not found in model shortcut name list ({}). " - "Assuming '{}' is a path or url to a directory containing tokenizer files.".format( + "Assuming '{}' is a path, a model identifier, or url to a directory containing tokenizer files.".format( pretrained_model_name_or_path, ", ".join(s3_models), pretrained_model_name_or_path ) ) - # Look for the tokenizer main vocabulary files - for file_id, file_name in cls.vocab_files_names.items(): - if os.path.isdir(pretrained_model_name_or_path): - # If a directory is provided we look for the standard filenames - full_file_name = os.path.join(pretrained_model_name_or_path, file_name) - if not os.path.exists(full_file_name): - logger.info("Didn't find file {}. We won't load it.".format(full_file_name)) - full_file_name = None - elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): - # If a path to a file is provided we use it (will only work for non-BPE tokenizer using a single vocabulary file) - full_file_name = pretrained_model_name_or_path - else: - full_file_name = hf_bucket_url(pretrained_model_name_or_path, postfix=file_name) - - vocab_files[file_id] = full_file_name - - # Look for the additional tokens files - additional_files_names = { - "added_tokens_file": ADDED_TOKENS_FILE, - "special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE, - "tokenizer_config_file": TOKENIZER_CONFIG_FILE, - } - - # If a path to a file was provided, get the parent directory - saved_directory = pretrained_model_name_or_path - if os.path.exists(saved_directory) and not os.path.isdir(saved_directory): - saved_directory = os.path.dirname(saved_directory) - - for file_id, file_name in additional_files_names.items(): - full_file_name = os.path.join(saved_directory, file_name) - if not os.path.exists(full_file_name): - logger.info("Didn't find file {}. We won't load it.".format(full_file_name)) - full_file_name = None - vocab_files[file_id] = full_file_name - - if all(full_file_name is None for full_file_name in vocab_files.values()): - raise EnvironmentError( - "Model name '{}' was not found in tokenizers model name list ({}). " - "We assumed '{}' was a path or url to a directory containing vocabulary files " - "named {} but couldn't find such vocabulary files at this path or url.".format( - pretrained_model_name_or_path, - ", ".join(s3_models), - pretrained_model_name_or_path, - list(cls.vocab_files_names.values()), + if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): + if len(cls.vocab_files_names) > 1: + raise ValueError( + "Calling {}.from_pretrained() with the path to a single file or url is not supported." + "Use a model identifier or the path to a directory instead.".format(cls.__name__) + ) + logger.warning( + "Calling {}.from_pretrained() with the path to a single file or url is deprecated".format( + cls.__name__ ) ) + file_id = list(cls.vocab_files_names.keys())[0] + vocab_files[file_id] = pretrained_model_name_or_path + else: + # At this point pretrained_model_name_or_path is either a directory or a model identifier name + additional_files_names = { + "added_tokens_file": ADDED_TOKENS_FILE, + "special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE, + "tokenizer_config_file": TOKENIZER_CONFIG_FILE, + } + # Look for the tokenizer main vocabulary files + the additional tokens files + for file_id, file_name in {**cls.vocab_files_names, **additional_files_names}.items(): + if os.path.isdir(pretrained_model_name_or_path): + full_file_name = os.path.join(pretrained_model_name_or_path, file_name) + if not os.path.exists(full_file_name): + logger.info("Didn't find file {}. We won't load it.".format(full_file_name)) + full_file_name = None + else: + full_file_name = hf_bucket_url(pretrained_model_name_or_path, postfix=file_name) + + vocab_files[file_id] = full_file_name # Get files from url, cache, or disk depending on the case try: @@ -414,6 +399,18 @@ class PreTrainedTokenizer(object): raise EnvironmentError(msg) + if all(full_file_name is None for full_file_name in resolved_vocab_files.values()): + raise EnvironmentError( + "Model name '{}' was not found in tokenizers model name list ({}). " + "We assumed '{}' was a path, a model identifier, or url to a directory containing vocabulary files " + "named {} but couldn't find such vocabulary files at this path or url.".format( + pretrained_model_name_or_path, + ", ".join(s3_models), + pretrained_model_name_or_path, + list(cls.vocab_files_names.values()), + ) + ) + for file_id, file_path in vocab_files.items(): if file_path == resolved_vocab_files[file_id]: logger.info("loading file {}".format(file_path)) diff --git a/tests/test_tokenization_auto.py b/tests/test_tokenization_auto.py index cd7187c4f2..261c064a40 100644 --- a/tests/test_tokenization_auto.py +++ b/tests/test_tokenization_auto.py @@ -56,3 +56,17 @@ class AutoTokenizerTest(unittest.TestCase): tokenizer = AutoTokenizer.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER) self.assertIsInstance(tokenizer, RobertaTokenizer) self.assertEqual(len(tokenizer), 20) + + def test_tokenizer_identifier_with_correct_config(self): + logging.basicConfig(level=logging.INFO) + for tokenizer_class in [BertTokenizer, AutoTokenizer]: + tokenizer = tokenizer_class.from_pretrained("wietsedv/bert-base-dutch-cased") + self.assertIsInstance(tokenizer, BertTokenizer) + self.assertEqual(tokenizer.basic_tokenizer.do_lower_case, False) + self.assertEqual(tokenizer.max_len, 512) + + def test_tokenizer_identifier_non_existent(self): + logging.basicConfig(level=logging.INFO) + for tokenizer_class in [BertTokenizer, AutoTokenizer]: + with self.assertRaises(EnvironmentError): + _ = tokenizer_class.from_pretrained("julien-c/herlolip-not-exists")