From 6cb0a6f01a8e459e618cd123205d5db2f8afb0b1 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Thu, 28 Jan 2021 09:29:12 +0100 Subject: [PATCH] Partial local tokenizer load (#9807) * Allow partial loading of a cached tokenizer * Warning > Info * Update src/transformers/tokenization_utils_base.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Raise error if not local_files_only Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/file_utils.py | 2 +- src/transformers/tokenization_utils_base.py | 35 +++++++++++++++------ 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index fd22962f0d..fc4f73b686 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -1239,7 +1239,7 @@ def get_from_cache( # the models might've been found if local_files_only=False # Notify the user about that if local_files_only: - raise ValueError( + raise FileNotFoundError( "Cannot find the requested files in the cached path and outgoing traffic has been" " disabled. To enable model look-ups and downloads online, set 'local_files_only'" " to False." diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index f3161c710b..8544547d82 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1730,20 +1730,28 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): # Get files from url, cache, or disk depending on the case resolved_vocab_files = {} + unresolved_files = [] for file_id, file_path in vocab_files.items(): if file_path is None: resolved_vocab_files[file_id] = None else: try: - resolved_vocab_files[file_id] = cached_path( - file_path, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - ) + try: + resolved_vocab_files[file_id] = cached_path( + file_path, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + ) + except FileNotFoundError as error: + if local_files_only: + unresolved_files.append(file_id) + else: + raise error + except requests.exceptions.HTTPError as err: if "404 Client Error" in str(err): logger.debug(err) @@ -1751,6 +1759,12 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): else: raise err + if len(unresolved_files) > 0: + logger.info( + f"Can't load following files from cache: {unresolved_files} and cannot check if these " + "files are necessary for the tokenizer to operate." + ) + if all(full_file_name is None for full_file_name in resolved_vocab_files.values()): msg = ( f"Can't load tokenizer for '{pretrained_model_name_or_path}'. Make sure that:\n\n" @@ -1760,6 +1774,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): raise EnvironmentError(msg) for file_id, file_path in vocab_files.items(): + if file_id not in resolved_vocab_files: + continue + if file_path == resolved_vocab_files[file_id]: logger.info("loading file {}".format(file_path)) else: