Avoid using get_list_of_files (#15287)
* Avoid using get_list_of_files in config * Wip, change tokenizer file getter * Remove call in tokenizer files * Remove last call to get_list_model_files * Better tests * Unit tests for new function * Document bad API
This commit is contained in:
@@ -21,7 +21,7 @@ import json
|
||||
import os
|
||||
import re
|
||||
import warnings
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
from packaging import version
|
||||
|
||||
@@ -36,7 +36,6 @@ from .file_utils import (
|
||||
RevisionNotFoundError,
|
||||
cached_path,
|
||||
copy_func,
|
||||
get_list_of_files,
|
||||
hf_bucket_url,
|
||||
is_offline_mode,
|
||||
is_remote_url,
|
||||
@@ -46,7 +45,7 @@ from .utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
FULL_CONFIGURATION_FILE = "config.json"
|
||||
|
||||
_re_configuration_file = re.compile(r"config\.(.*)\.json")
|
||||
|
||||
|
||||
@@ -533,6 +532,23 @@ class PretrainedConfig(PushToHubMixin):
|
||||
`Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object.
|
||||
|
||||
"""
|
||||
original_kwargs = copy.deepcopy(kwargs)
|
||||
# Get config dict associated with the base config file
|
||||
config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
# That config file may point us toward another config file to use.
|
||||
if "configuration_files" in config_dict:
|
||||
configuration_file = get_configuration_file(config_dict["configuration_files"])
|
||||
config_dict, kwargs = cls._get_config_dict(
|
||||
pretrained_model_name_or_path, _configuration_file=configuration_file, **original_kwargs
|
||||
)
|
||||
|
||||
return config_dict, kwargs
|
||||
|
||||
@classmethod
|
||||
def _get_config_dict(
|
||||
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
@@ -555,12 +571,7 @@ class PretrainedConfig(PushToHubMixin):
|
||||
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
config_file = pretrained_model_name_or_path
|
||||
else:
|
||||
configuration_file = get_configuration_file(
|
||||
pretrained_model_name_or_path,
|
||||
revision=revision,
|
||||
use_auth_token=use_auth_token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
configuration_file = kwargs.get("_configuration_file", CONFIG_NAME)
|
||||
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
config_file = os.path.join(pretrained_model_name_or_path, configuration_file)
|
||||
@@ -840,41 +851,18 @@ class PretrainedConfig(PushToHubMixin):
|
||||
d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
|
||||
|
||||
|
||||
def get_configuration_file(
|
||||
path_or_repo: Union[str, os.PathLike],
|
||||
revision: Optional[str] = None,
|
||||
use_auth_token: Optional[Union[bool, str]] = None,
|
||||
local_files_only: bool = False,
|
||||
) -> str:
|
||||
def get_configuration_file(configuration_files: List[str]) -> str:
|
||||
"""
|
||||
Get the configuration file to use for this version of transformers.
|
||||
|
||||
Args:
|
||||
path_or_repo (`str` or `os.PathLike`):
|
||||
Can be either the id of a repo on huggingface.co or a path to a *directory*.
|
||||
revision(`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
||||
identifier allowed by git.
|
||||
use_auth_token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
||||
when running `transformers-cli login` (stored in `~/.huggingface`).
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to only rely on local files and not to attempt to download any files.
|
||||
configuration_files (`List[str]`): The list of available configuration files.
|
||||
|
||||
Returns:
|
||||
`str`: The configuration file to use.
|
||||
"""
|
||||
# Inspect all files from the repo/folder.
|
||||
try:
|
||||
all_files = get_list_of_files(
|
||||
path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only
|
||||
)
|
||||
except Exception:
|
||||
return FULL_CONFIGURATION_FILE
|
||||
|
||||
configuration_files_map = {}
|
||||
for file_name in all_files:
|
||||
for file_name in configuration_files:
|
||||
search = _re_configuration_file.search(file_name)
|
||||
if search is not None:
|
||||
v = search.groups()[0]
|
||||
@@ -882,7 +870,7 @@ def get_configuration_file(
|
||||
available_versions = sorted(configuration_files_map.keys())
|
||||
|
||||
# Defaults to FULL_CONFIGURATION_FILE and then try to look at some newer versions.
|
||||
configuration_file = FULL_CONFIGURATION_FILE
|
||||
configuration_file = CONFIG_NAME
|
||||
transformers_version = version.parse(__version__)
|
||||
for v in available_versions:
|
||||
if version.parse(v) <= transformers_version:
|
||||
|
||||
@@ -2112,6 +2112,112 @@ def get_from_cache(
|
||||
return cache_path
|
||||
|
||||
|
||||
def get_file_from_repo(
|
||||
path_or_repo: Union[str, os.PathLike],
|
||||
filename: str,
|
||||
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
||||
force_download: bool = False,
|
||||
resume_download: bool = False,
|
||||
proxies: Optional[Dict[str, str]] = None,
|
||||
use_auth_token: Optional[Union[bool, str]] = None,
|
||||
revision: Optional[str] = None,
|
||||
local_files_only: bool = False,
|
||||
):
|
||||
"""
|
||||
Tries to locate a file in a local folder and repo, downloads and cache it if necessary.
|
||||
|
||||
Args:
|
||||
path_or_repo (`str` or `os.PathLike`):
|
||||
This can be either:
|
||||
|
||||
- a string, the *model id* of a model repo on huggingface.co.
|
||||
- a path to a *directory* potentially containing the file.
|
||||
filename (`str`):
|
||||
The name of the file to locate in `path_or_repo`.
|
||||
cache_dir (`str` or `os.PathLike`, *optional*):
|
||||
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
|
||||
cache should not be used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force to (re-)download the configuration files and override the cached versions if they
|
||||
exist.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
|
||||
use_auth_token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
||||
when running `transformers-cli login` (stored in `~/.huggingface`).
|
||||
revision(`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
||||
identifier allowed by git.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
If `True`, will only try to load the tokenizer configuration from local files.
|
||||
|
||||
<Tip>
|
||||
|
||||
Passing `use_auth_token=True` is required when you want to use a private model.
|
||||
|
||||
</Tip>
|
||||
|
||||
Returns:
|
||||
`Optional[str]`: Returns the resolved file (to the cache folder if downloaded from a repo) or `None` if the
|
||||
file does not exist.
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
# Download a tokenizer configuration from huggingface.co and cache.
|
||||
tokenizer_config = get_file_from_repo("bert-base-uncased", "tokenizer_config.json")
|
||||
# This model does not have a tokenizer config so the result will be None.
|
||||
tokenizer_config = get_file_from_repo("xlm-roberta-base", "tokenizer_config.json")
|
||||
```"""
|
||||
if is_offline_mode() and not local_files_only:
|
||||
logger.info("Offline mode: forcing local_files_only=True")
|
||||
local_files_only = True
|
||||
|
||||
path_or_repo = str(path_or_repo)
|
||||
if os.path.isdir(path_or_repo):
|
||||
resolved_file = os.path.join(path_or_repo, filename)
|
||||
return resolved_file if os.path.isfile(resolved_file) else None
|
||||
else:
|
||||
resolved_file = hf_bucket_url(path_or_repo, filename=filename, revision=revision, mirror=None)
|
||||
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
resolved_file = cached_path(
|
||||
resolved_file,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
|
||||
except RepositoryNotFoundError as err:
|
||||
logger.error(err)
|
||||
raise EnvironmentError(
|
||||
f"{path_or_repo} is not a local folder and is not a valid model identifier "
|
||||
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to "
|
||||
"pass a token having permission to this repo with `use_auth_token` or log in with "
|
||||
"`huggingface-cli login` and pass `use_auth_token=True`."
|
||||
)
|
||||
except RevisionNotFoundError as err:
|
||||
logger.error(err)
|
||||
raise EnvironmentError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
|
||||
"for this model name. Check the model page at "
|
||||
f"'https://huggingface.co/{path_or_repo}' for available revisions."
|
||||
)
|
||||
except EnvironmentError:
|
||||
# The repo and revision exist, but the file does not or there was a connection error fetching it.
|
||||
return None
|
||||
|
||||
return resolved_file
|
||||
|
||||
|
||||
def has_file(
|
||||
path_or_repo: Union[str, os.PathLike],
|
||||
filename: str,
|
||||
@@ -2184,6 +2290,12 @@ def get_list_of_files(
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to only rely on local files and not to attempt to download any files.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is not optimized, so calling it a lot may result in connection errors.
|
||||
|
||||
</Tip>
|
||||
|
||||
Returns:
|
||||
`List[str]`: The list of files available in `path_or_repo`.
|
||||
"""
|
||||
|
||||
@@ -14,12 +14,14 @@
|
||||
# limitations under the License.
|
||||
""" AutoProcessor class."""
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
|
||||
# Build the list of all feature extractors
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...feature_extraction_utils import FeatureExtractionMixin
|
||||
from ...file_utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, get_list_of_files
|
||||
from ...file_utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, get_file_from_repo
|
||||
from ...tokenization_utils import TOKENIZER_CONFIG_FILE
|
||||
from .auto_factory import _LazyAutoMapping
|
||||
from .configuration_auto import (
|
||||
@@ -29,7 +31,6 @@ from .configuration_auto import (
|
||||
model_type_to_module_name,
|
||||
replace_list_option_in_docstrings,
|
||||
)
|
||||
from .tokenization_auto import get_tokenizer_config
|
||||
|
||||
|
||||
PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
@@ -145,24 +146,29 @@ class AutoProcessor:
|
||||
kwargs["_from_auto"] = True
|
||||
|
||||
# First, let's see if we have a preprocessor config.
|
||||
# get_list_of_files only takes three of the kwargs we have, so we filter them.
|
||||
get_list_of_files_kwargs = {
|
||||
key: kwargs[key] for key in ["revision", "use_auth_token", "local_files_only"] if key in kwargs
|
||||
# Filter the kwargs for `get_file_from_repo``.
|
||||
get_file_from_repo_kwargs = {
|
||||
key: kwargs[key] for key in inspect.signature(get_file_from_repo).parameters.keys() if key in kwargs
|
||||
}
|
||||
model_files = get_list_of_files(pretrained_model_name_or_path, **get_list_of_files_kwargs)
|
||||
# strip to file name
|
||||
model_files = [f.split("/")[-1] for f in model_files]
|
||||
|
||||
# Let's start by checking whether the processor class is saved in a feature extractor
|
||||
if FEATURE_EXTRACTOR_NAME in model_files:
|
||||
preprocessor_config_file = get_file_from_repo(
|
||||
pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, **get_file_from_repo_kwargs
|
||||
)
|
||||
if preprocessor_config_file is not None:
|
||||
config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
|
||||
if "processor_class" in config_dict:
|
||||
processor_class = processor_class_from_name(config_dict["processor_class"])
|
||||
return processor_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
# Next, let's check whether the processor class is saved in a tokenizer
|
||||
if TOKENIZER_CONFIG_FILE in model_files:
|
||||
config_dict = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
|
||||
# Let's start by checking whether the processor class is saved in a feature extractor
|
||||
tokenizer_config_file = get_file_from_repo(
|
||||
pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE, **get_file_from_repo_kwargs
|
||||
)
|
||||
if tokenizer_config_file is not None:
|
||||
with open(tokenizer_config_file, encoding="utf-8") as reader:
|
||||
config_dict = json.load(reader)
|
||||
|
||||
if "processor_class" in config_dict:
|
||||
processor_class = processor_class_from_name(config_dict["processor_class"])
|
||||
return processor_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
@@ -21,15 +21,7 @@ from collections import OrderedDict
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...file_utils import (
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
cached_path,
|
||||
hf_bucket_url,
|
||||
is_offline_mode,
|
||||
is_sentencepiece_available,
|
||||
is_tokenizers_available,
|
||||
)
|
||||
from ...file_utils import get_file_from_repo, is_sentencepiece_available, is_tokenizers_available
|
||||
from ...tokenization_utils import PreTrainedTokenizer
|
||||
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
|
||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
@@ -329,46 +321,18 @@ def get_tokenizer_config(
|
||||
tokenizer.save_pretrained("tokenizer-test")
|
||||
tokenizer_config = get_tokenizer_config("tokenizer-test")
|
||||
```"""
|
||||
if is_offline_mode() and not local_files_only:
|
||||
logger.info("Offline mode: forcing local_files_only=True")
|
||||
local_files_only = True
|
||||
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
config_file = os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE)
|
||||
else:
|
||||
config_file = hf_bucket_url(
|
||||
pretrained_model_name_or_path, filename=TOKENIZER_CONFIG_FILE, revision=revision, mirror=None
|
||||
)
|
||||
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
resolved_config_file = cached_path(
|
||||
config_file,
|
||||
resolved_config_file = get_file_from_repo(
|
||||
pretrained_model_name_or_path,
|
||||
TOKENIZER_CONFIG_FILE,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
proxies=proxies,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
|
||||
except RepositoryNotFoundError as err:
|
||||
logger.error(err)
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
||||
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to "
|
||||
"pass a token having permission to this repo with `use_auth_token` or log in with "
|
||||
"`huggingface-cli login` and pass `use_auth_token=True`."
|
||||
)
|
||||
except RevisionNotFoundError as err:
|
||||
logger.error(err)
|
||||
raise EnvironmentError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
|
||||
"for this model name. Check the model page at "
|
||||
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
||||
)
|
||||
except EnvironmentError:
|
||||
if resolved_config_file is None:
|
||||
logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
|
||||
return {}
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ from .file_utils import (
|
||||
add_end_docstrings,
|
||||
cached_path,
|
||||
copy_func,
|
||||
get_list_of_files,
|
||||
get_file_from_repo,
|
||||
hf_bucket_url,
|
||||
is_flax_available,
|
||||
is_offline_mode,
|
||||
@@ -1649,12 +1649,26 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
vocab_files[file_id] = pretrained_model_name_or_path
|
||||
else:
|
||||
# At this point pretrained_model_name_or_path is either a directory or a model identifier name
|
||||
fast_tokenizer_file = get_fast_tokenizer_file(
|
||||
|
||||
# Try to get the tokenizer config to see if there are versioned tokenizer files.
|
||||
fast_tokenizer_file = FULL_TOKENIZER_FILE
|
||||
resolved_config_file = get_file_from_repo(
|
||||
pretrained_model_name_or_path,
|
||||
revision=revision,
|
||||
TOKENIZER_CONFIG_FILE,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
if resolved_config_file is not None:
|
||||
with open(resolved_config_file, encoding="utf-8") as reader:
|
||||
tokenizer_config = json.load(reader)
|
||||
if "fast_tokenizer_files" in tokenizer_config:
|
||||
fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"])
|
||||
|
||||
additional_files_names = {
|
||||
"added_tokens_file": ADDED_TOKENS_FILE,
|
||||
"special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
|
||||
@@ -3495,41 +3509,18 @@ For a more complete example, see the implementation of `prepare_seq2seq_batch`.
|
||||
return model_inputs
|
||||
|
||||
|
||||
def get_fast_tokenizer_file(
|
||||
path_or_repo: Union[str, os.PathLike],
|
||||
revision: Optional[str] = None,
|
||||
use_auth_token: Optional[Union[bool, str]] = None,
|
||||
local_files_only: bool = False,
|
||||
) -> str:
|
||||
def get_fast_tokenizer_file(tokenization_files: List[str]) -> str:
|
||||
"""
|
||||
Get the tokenizer file to use for this version of transformers.
|
||||
Get the tokenization file to use for this version of transformers.
|
||||
|
||||
Args:
|
||||
path_or_repo (`str` or `os.PathLike`):
|
||||
Can be either the id of a repo on huggingface.co or a path to a *directory*.
|
||||
revision(`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
||||
identifier allowed by git.
|
||||
use_auth_token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
||||
when running `transformers-cli login` (stored in `~/.huggingface`).
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to only rely on local files and not to attempt to download any files.
|
||||
tokenization_files (`List[str]`): The list of available configuration files.
|
||||
|
||||
Returns:
|
||||
`str`: The tokenizer file to use.
|
||||
`str`: The tokenization file to use.
|
||||
"""
|
||||
# Inspect all files from the repo/folder.
|
||||
try:
|
||||
all_files = get_list_of_files(
|
||||
path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only
|
||||
)
|
||||
except Exception:
|
||||
return FULL_TOKENIZER_FILE
|
||||
|
||||
tokenizer_files_map = {}
|
||||
for file_name in all_files:
|
||||
for file_name in tokenization_files:
|
||||
search = _re_tokenizer_file.search(file_name)
|
||||
if search is not None:
|
||||
v = search.groups()[0]
|
||||
|
||||
@@ -313,6 +313,7 @@ class ConfigTestUtils(unittest.TestCase):
|
||||
class ConfigurationVersioningTest(unittest.TestCase):
|
||||
def test_local_versioning(self):
|
||||
configuration = AutoConfig.from_pretrained("bert-base-cased")
|
||||
configuration.configuration_files = ["config.4.0.0.json"]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
configuration.save_pretrained(tmp_dir)
|
||||
@@ -325,23 +326,26 @@ class ConfigurationVersioningTest(unittest.TestCase):
|
||||
|
||||
# Will need to be adjusted if we reach v42 and this test is still here.
|
||||
# Should pick the old configuration file as the version of Transformers is < 4.42.0
|
||||
configuration.configuration_files = ["config.42.0.0.json"]
|
||||
configuration.hidden_size = 768
|
||||
configuration.save_pretrained(tmp_dir)
|
||||
shutil.move(os.path.join(tmp_dir, "config.4.0.0.json"), os.path.join(tmp_dir, "config.42.0.0.json"))
|
||||
new_configuration = AutoConfig.from_pretrained(tmp_dir)
|
||||
self.assertEqual(new_configuration.hidden_size, 768)
|
||||
|
||||
def test_repo_versioning_before(self):
|
||||
# This repo has two configuration files, one for v5.0.0 and above with an added token, one for versions lower.
|
||||
repo = "microsoft/layoutxlm-base"
|
||||
# This repo has two configuration files, one for v4.0.0 and above with a different hidden size.
|
||||
repo = "hf-internal-testing/test-two-configs"
|
||||
|
||||
import transformers as new_transformers
|
||||
|
||||
new_transformers.configuration_utils.__version__ = "v5.0.0"
|
||||
new_transformers.configuration_utils.__version__ = "v4.0.0"
|
||||
new_configuration = new_transformers.models.auto.AutoConfig.from_pretrained(repo)
|
||||
self.assertEqual(new_configuration.tokenizer_class, None)
|
||||
self.assertEqual(new_configuration.hidden_size, 2)
|
||||
|
||||
# Testing an older version by monkey-patching the version in the module it's used.
|
||||
import transformers as old_transformers
|
||||
|
||||
old_transformers.configuration_utils.__version__ = "v3.0.0"
|
||||
old_configuration = old_transformers.models.auto.AutoConfig.from_pretrained(repo)
|
||||
self.assertEqual(old_configuration.tokenizer_class, "XLMRobertaTokenizer")
|
||||
self.assertEqual(old_configuration.hidden_size, 768)
|
||||
|
||||
@@ -15,7 +15,10 @@
|
||||
import contextlib
|
||||
import importlib
|
||||
import io
|
||||
import json
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import transformers
|
||||
|
||||
@@ -31,6 +34,7 @@ from transformers.file_utils import (
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
filename_to_url,
|
||||
get_file_from_repo,
|
||||
get_from_cache,
|
||||
has_file,
|
||||
hf_bucket_url,
|
||||
@@ -128,6 +132,31 @@ class GetFromCacheTests(unittest.TestCase):
|
||||
self.assertFalse(has_file("hf-internal-testing/tiny-bert-pt-only", TF2_WEIGHTS_NAME))
|
||||
self.assertFalse(has_file("hf-internal-testing/tiny-bert-pt-only", FLAX_WEIGHTS_NAME))
|
||||
|
||||
def test_get_file_from_repo_distant(self):
|
||||
# `get_file_from_repo` returns None if the file does not exist
|
||||
self.assertIsNone(get_file_from_repo("bert-base-cased", "ahah.txt"))
|
||||
|
||||
# The function raises if the repository does not exist.
|
||||
with self.assertRaisesRegex(EnvironmentError, "is not a valid model identifier"):
|
||||
get_file_from_repo("bert-base-case", "config.json")
|
||||
|
||||
# The function raises if the revision does not exist.
|
||||
with self.assertRaisesRegex(EnvironmentError, "is not a valid git identifier"):
|
||||
get_file_from_repo("bert-base-cased", "config.json", revision="ahaha")
|
||||
|
||||
resolved_file = get_file_from_repo("bert-base-cased", "config.json")
|
||||
# The name is the cached name which is not very easy to test, so instead we load the content.
|
||||
config = json.loads(open(resolved_file, "r").read())
|
||||
self.assertEqual(config["hidden_size"], 768)
|
||||
|
||||
def test_get_file_from_repo_local(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
filename = Path(tmp_dir) / "a.txt"
|
||||
filename.touch()
|
||||
self.assertEqual(get_file_from_repo(tmp_dir, "a.txt"), str(filename))
|
||||
|
||||
self.assertIsNone(get_file_from_repo(tmp_dir, "b.txt"))
|
||||
|
||||
|
||||
class ContextManagerTests(unittest.TestCase):
|
||||
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
|
||||
|
||||
@@ -108,6 +108,8 @@ class TokenizerVersioningTest(unittest.TestCase):
|
||||
json_tokenizer["model"]["vocab"]["huggingface"] = len(tokenizer)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Hack to save this in the tokenizer_config.json
|
||||
tokenizer.init_kwargs["fast_tokenizer_files"] = ["tokenizer.4.0.0.json"]
|
||||
tokenizer.save_pretrained(tmp_dir)
|
||||
json.dump(json_tokenizer, open(os.path.join(tmp_dir, "tokenizer.4.0.0.json"), "w"))
|
||||
|
||||
@@ -120,6 +122,8 @@ class TokenizerVersioningTest(unittest.TestCase):
|
||||
# Will need to be adjusted if we reach v42 and this test is still here.
|
||||
# Should pick the old tokenizer file as the version of Transformers is < 4.0.0
|
||||
shutil.move(os.path.join(tmp_dir, "tokenizer.4.0.0.json"), os.path.join(tmp_dir, "tokenizer.42.0.0.json"))
|
||||
tokenizer.init_kwargs["fast_tokenizer_files"] = ["tokenizer.42.0.0.json"]
|
||||
tokenizer.save_pretrained(tmp_dir)
|
||||
new_tokenizer = AutoTokenizer.from_pretrained(tmp_dir)
|
||||
self.assertEqual(len(new_tokenizer), len(tokenizer))
|
||||
json_tokenizer = json.loads(new_tokenizer._tokenizer.to_str())
|
||||
@@ -127,7 +131,7 @@ class TokenizerVersioningTest(unittest.TestCase):
|
||||
|
||||
def test_repo_versioning(self):
|
||||
# This repo has two tokenizer files, one for v4.0.0 and above with an added token, one for versions lower.
|
||||
repo = "sgugger/finetuned-bert-mrpc"
|
||||
repo = "hf-internal-testing/test-two-tokenizers"
|
||||
|
||||
# This should pick the new tokenizer file as the version of Transformers is > 4.0.0
|
||||
tokenizer = AutoTokenizer.from_pretrained(repo)
|
||||
|
||||
Reference in New Issue
Block a user