Avoid using get_list_of_files (#15287)
* Avoid using get_list_of_files in config
* Wip, change tokenizer file getter
* Remove call in tokenizer files
* Remove last call to get_list_model_files
* Better tests
* Unit tests for new function
* Document bad API
This commit is contained in:
@@ -50,7 +50,7 @@ from .file_utils import (
|
||||
add_end_docstrings,
|
||||
cached_path,
|
||||
copy_func,
|
||||
get_list_of_files,
|
||||
get_file_from_repo,
|
||||
hf_bucket_url,
|
||||
is_flax_available,
|
||||
is_offline_mode,
|
||||
@@ -1649,12 +1649,26 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
vocab_files[file_id] = pretrained_model_name_or_path
|
||||
else:
|
||||
# At this point pretrained_model_name_or_path is either a directory or a model identifier name
|
||||
fast_tokenizer_file = get_fast_tokenizer_file(
|
||||
|
||||
# Try to get the tokenizer config to see if there are versioned tokenizer files.
|
||||
fast_tokenizer_file = FULL_TOKENIZER_FILE
|
||||
resolved_config_file = get_file_from_repo(
|
||||
pretrained_model_name_or_path,
|
||||
revision=revision,
|
||||
TOKENIZER_CONFIG_FILE,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
if resolved_config_file is not None:
|
||||
with open(resolved_config_file, encoding="utf-8") as reader:
|
||||
tokenizer_config = json.load(reader)
|
||||
if "fast_tokenizer_files" in tokenizer_config:
|
||||
fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"])
|
||||
|
||||
additional_files_names = {
|
||||
"added_tokens_file": ADDED_TOKENS_FILE,
|
||||
"special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
|
||||
@@ -3495,41 +3509,18 @@ For a more complete example, see the implementation of `prepare_seq2seq_batch`.
|
||||
return model_inputs
|
||||
|
||||
|
||||
def get_fast_tokenizer_file(
|
||||
path_or_repo: Union[str, os.PathLike],
|
||||
revision: Optional[str] = None,
|
||||
use_auth_token: Optional[Union[bool, str]] = None,
|
||||
local_files_only: bool = False,
|
||||
) -> str:
|
||||
def get_fast_tokenizer_file(tokenization_files: List[str]) -> str:
|
||||
"""
|
||||
Get the tokenizer file to use for this version of transformers.
|
||||
Get the tokenization file to use for this version of transformers.
|
||||
|
||||
Args:
|
||||
path_or_repo (`str` or `os.PathLike`):
|
||||
Can be either the id of a repo on huggingface.co or a path to a *directory*.
|
||||
revision(`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
||||
identifier allowed by git.
|
||||
use_auth_token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
||||
when running `transformers-cli login` (stored in `~/.huggingface`).
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to only rely on local files and not to attempt to download any files.
|
||||
tokenization_files (`List[str]`): The list of available configuration files.
|
||||
|
||||
Returns:
|
||||
`str`: The tokenizer file to use.
|
||||
`str`: The tokenization file to use.
|
||||
"""
|
||||
# Inspect all files from the repo/folder.
|
||||
try:
|
||||
all_files = get_list_of_files(
|
||||
path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only
|
||||
)
|
||||
except Exception:
|
||||
return FULL_TOKENIZER_FILE
|
||||
|
||||
tokenizer_files_map = {}
|
||||
for file_name in all_files:
|
||||
for file_name in tokenization_files:
|
||||
search = _re_tokenizer_file.search(file_name)
|
||||
if search is not None:
|
||||
v = search.groups()[0]
|
||||
|
||||
Reference in New Issue
Block a user