From 39db2f3c192b830a82fcfc983a95cf993a7508e8 Mon Sep 17 00:00:00 2001 From: Bram Vanroy Date: Tue, 24 Aug 2021 09:05:33 +0200 Subject: [PATCH] Allow local_files_only for fast pretrained tokenizers (#13225) * allow local_files_only for fast pretrained tokenizers * make style --- src/transformers/file_utils.py | 5 ++++- src/transformers/tokenization_utils_base.py | 14 ++++++++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index a395ec275f..0b31c17bd5 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -1654,6 +1654,7 @@ def get_list_of_files( path_or_repo: Union[str, os.PathLike], revision: Optional[str] = None, use_auth_token: Optional[Union[bool, str]] = None, + local_files_only: bool = False, ) -> List[str]: """ Gets the list of files inside :obj:`path_or_repo`. @@ -1668,6 +1669,8 @@ def get_list_of_files( use_auth_token (:obj:`str` or `bool`, `optional`): The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). + local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to only rely on local files and not to attempt to download any files. Returns: :obj:`List[str]`: The list of files available in :obj:`path_or_repo`. @@ -1681,7 +1684,7 @@ def get_list_of_files( return list_of_files # Can't grab the files if we are on offline mode. - if is_offline_mode(): + if is_offline_mode() or local_files_only: return [] # Otherwise we grab the token and use the model_info method. diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 5bd7c06540..e57299cf77 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1566,6 +1566,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): use_auth_token (:obj:`str` or `bool`, `optional`): The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). + local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to only rely on local files and not to attempt to download any files. revision(:obj:`str`, `optional`, defaults to :obj:`"main"`): The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any @@ -1645,7 +1647,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): else: # At this point pretrained_model_name_or_path is either a directory or a model identifier name fast_tokenizer_file = get_fast_tokenizer_file( - pretrained_model_name_or_path, revision=revision, use_auth_token=use_auth_token + pretrained_model_name_or_path, + revision=revision, + use_auth_token=use_auth_token, + local_files_only=local_files_only, ) additional_files_names = { "added_tokens_file": ADDED_TOKENS_FILE, @@ -3389,6 +3394,7 @@ def get_fast_tokenizer_file( path_or_repo: Union[str, os.PathLike], revision: Optional[str] = None, use_auth_token: Optional[Union[bool, str]] = None, + local_files_only: bool = False, ) -> str: """ Get the tokenizer file to use for this version of transformers. @@ -3403,12 +3409,16 @@ def get_fast_tokenizer_file( use_auth_token (:obj:`str` or `bool`, `optional`): The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). + local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to only rely on local files and not to attempt to download any files. Returns: :obj:`str`: The tokenizer file to use. """ # Inspect all files from the repo/folder. - all_files = get_list_of_files(path_or_repo, revision=revision, use_auth_token=use_auth_token) + all_files = get_list_of_files( + path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only + ) tokenizer_files_map = {} for file_name in all_files: search = _re_tokenizer_file.search(file_name)