Partial local tokenizer load (#9807)

* Allow partial loading of a cached tokenizer

* Warning > Info

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Raise error if not local_files_only

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Lysandre Debut
2021-01-28 09:29:12 +01:00
committed by GitHub
parent 25fcb5c171
commit 6cb0a6f01a
2 changed files with 27 additions and 10 deletions

View File

@@ -1239,7 +1239,7 @@ def get_from_cache(
# the models might've been found if local_files_only=False # the models might've been found if local_files_only=False
# Notify the user about that # Notify the user about that
if local_files_only: if local_files_only:
raise ValueError( raise FileNotFoundError(
"Cannot find the requested files in the cached path and outgoing traffic has been" "Cannot find the requested files in the cached path and outgoing traffic has been"
" disabled. To enable model look-ups and downloads online, set 'local_files_only'" " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
" to False." " to False."

View File

@@ -1730,20 +1730,28 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
# Get files from url, cache, or disk depending on the case # Get files from url, cache, or disk depending on the case
resolved_vocab_files = {} resolved_vocab_files = {}
unresolved_files = []
for file_id, file_path in vocab_files.items(): for file_id, file_path in vocab_files.items():
if file_path is None: if file_path is None:
resolved_vocab_files[file_id] = None resolved_vocab_files[file_id] = None
else: else:
try: try:
resolved_vocab_files[file_id] = cached_path( try:
file_path, resolved_vocab_files[file_id] = cached_path(
cache_dir=cache_dir, file_path,
force_download=force_download, cache_dir=cache_dir,
proxies=proxies, force_download=force_download,
resume_download=resume_download, proxies=proxies,
local_files_only=local_files_only, resume_download=resume_download,
use_auth_token=use_auth_token, local_files_only=local_files_only,
) use_auth_token=use_auth_token,
)
except FileNotFoundError as error:
if local_files_only:
unresolved_files.append(file_id)
else:
raise error
except requests.exceptions.HTTPError as err: except requests.exceptions.HTTPError as err:
if "404 Client Error" in str(err): if "404 Client Error" in str(err):
logger.debug(err) logger.debug(err)
@@ -1751,6 +1759,12 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
else: else:
raise err raise err
if len(unresolved_files) > 0:
logger.info(
f"Can't load following files from cache: {unresolved_files} and cannot check if these "
"files are necessary for the tokenizer to operate."
)
if all(full_file_name is None for full_file_name in resolved_vocab_files.values()): if all(full_file_name is None for full_file_name in resolved_vocab_files.values()):
msg = ( msg = (
f"Can't load tokenizer for '{pretrained_model_name_or_path}'. Make sure that:\n\n" f"Can't load tokenizer for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
@@ -1760,6 +1774,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
raise EnvironmentError(msg) raise EnvironmentError(msg)
for file_id, file_path in vocab_files.items(): for file_id, file_path in vocab_files.items():
if file_id not in resolved_vocab_files:
continue
if file_path == resolved_vocab_files[file_id]: if file_path == resolved_vocab_files[file_id]:
logger.info("loading file {}".format(file_path)) logger.info("loading file {}".format(file_path))
else: else: