Avoid using get_list_of_files (#15287)

* Avoid using get_list_of_files in config

* WIP, change tokenizer file getter

* Remove call in tokenizer files

* Remove last call to get_list_of_files

* Better tests

* Unit tests for new function

* Document bad API
This commit is contained in:
Sylvain Gugger
2022-01-25 09:41:21 -05:00
committed by GitHub
parent e65bfc0971
commit e695470794
8 changed files with 232 additions and 134 deletions

View File

@@ -50,7 +50,7 @@ from .file_utils import (
add_end_docstrings,
cached_path,
copy_func,
get_list_of_files,
get_file_from_repo,
hf_bucket_url,
is_flax_available,
is_offline_mode,
@@ -1649,12 +1649,26 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
vocab_files[file_id] = pretrained_model_name_or_path
else:
# At this point pretrained_model_name_or_path is either a directory or a model identifier name
fast_tokenizer_file = get_fast_tokenizer_file(
# Try to get the tokenizer config to see if there are versioned tokenizer files.
fast_tokenizer_file = FULL_TOKENIZER_FILE
resolved_config_file = get_file_from_repo(
pretrained_model_name_or_path,
revision=revision,
TOKENIZER_CONFIG_FILE,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
use_auth_token=use_auth_token,
revision=revision,
local_files_only=local_files_only,
)
if resolved_config_file is not None:
with open(resolved_config_file, encoding="utf-8") as reader:
tokenizer_config = json.load(reader)
if "fast_tokenizer_files" in tokenizer_config:
fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"])
additional_files_names = {
"added_tokens_file": ADDED_TOKENS_FILE,
"special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
@@ -3495,41 +3509,18 @@ For a more complete example, see the implementation of `prepare_seq2seq_batch`.
return model_inputs
def get_fast_tokenizer_file(
path_or_repo: Union[str, os.PathLike],
revision: Optional[str] = None,
use_auth_token: Optional[Union[bool, str]] = None,
local_files_only: bool = False,
) -> str:
def get_fast_tokenizer_file(tokenization_files: List[str]) -> str:
"""
Get the tokenizer file to use for this version of transformers.
Get the tokenization file to use for this version of transformers.
Args:
path_or_repo (`str` or `os.PathLike`):
Can be either the id of a repo on huggingface.co or a path to a *directory*.
revision(`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `transformers-cli login` (stored in `~/.huggingface`).
local_files_only (`bool`, *optional*, defaults to `False`):
Whether or not to only rely on local files and not to attempt to download any files.
tokenization_files (`List[str]`): The list of available tokenization files.
Returns:
`str`: The tokenizer file to use.
`str`: The tokenization file to use.
"""
# Inspect all files from the repo/folder.
try:
all_files = get_list_of_files(
path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only
)
except Exception:
return FULL_TOKENIZER_FILE
tokenizer_files_map = {}
for file_name in all_files:
for file_name in tokenization_files:
search = _re_tokenizer_file.search(file_name)
if search is not None:
v = search.groups()[0]