Allow per-version configurations (#14344)
* Allow per-version configurations * Update tests/test_configuration_common.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update tests/test_configuration_common.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -19,8 +19,11 @@
|
|||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Dict, Tuple, Union
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
from . import __version__
|
from . import __version__
|
||||||
from .file_utils import (
|
from .file_utils import (
|
||||||
@@ -28,6 +31,7 @@ from .file_utils import (
|
|||||||
PushToHubMixin,
|
PushToHubMixin,
|
||||||
cached_path,
|
cached_path,
|
||||||
copy_func,
|
copy_func,
|
||||||
|
get_list_of_files,
|
||||||
hf_bucket_url,
|
hf_bucket_url,
|
||||||
is_offline_mode,
|
is_offline_mode,
|
||||||
is_remote_url,
|
is_remote_url,
|
||||||
@@ -37,6 +41,8 @@ from .utils import logging
|
|||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
FULL_CONFIGURATION_FILE = "config.json"
|
||||||
|
_re_configuration_file = re.compile(r"config\.(.*)\.json")
|
||||||
|
|
||||||
|
|
||||||
class PretrainedConfig(PushToHubMixin):
|
class PretrainedConfig(PushToHubMixin):
|
||||||
@@ -536,13 +542,21 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
local_files_only = True
|
local_files_only = True
|
||||||
|
|
||||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||||
if os.path.isdir(pretrained_model_name_or_path):
|
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||||
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
|
|
||||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
|
||||||
config_file = pretrained_model_name_or_path
|
config_file = pretrained_model_name_or_path
|
||||||
|
else:
|
||||||
|
configuration_file = get_configuration_file(
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
revision=revision,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
)
|
||||||
|
|
||||||
|
if os.path.isdir(pretrained_model_name_or_path):
|
||||||
|
config_file = os.path.join(pretrained_model_name_or_path, configuration_file)
|
||||||
else:
|
else:
|
||||||
config_file = hf_bucket_url(
|
config_file = hf_bucket_url(
|
||||||
pretrained_model_name_or_path, filename=CONFIG_NAME, revision=revision, mirror=None
|
pretrained_model_name_or_path, filename=configuration_file, revision=revision, mirror=None
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -796,6 +810,56 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
|
d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
|
||||||
|
|
||||||
|
|
||||||
|
def get_configuration_file(
|
||||||
|
path_or_repo: Union[str, os.PathLike],
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
use_auth_token: Optional[Union[bool, str]] = None,
|
||||||
|
local_files_only: bool = False,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Get the configuration file to use for this version of transformers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path_or_repo (:obj:`str` or :obj:`os.PathLike`):
|
||||||
|
Can be either the id of a repo on huggingface.co or a path to a `directory`.
|
||||||
|
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
|
||||||
|
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||||
|
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
|
||||||
|
identifier allowed by git.
|
||||||
|
use_auth_token (:obj:`str` or `bool`, `optional`):
|
||||||
|
The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
|
||||||
|
generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
|
||||||
|
local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not to only rely on local files and not to attempt to download any files.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:obj:`str`: The configuration file to use.
|
||||||
|
"""
|
||||||
|
# Inspect all files from the repo/folder.
|
||||||
|
all_files = get_list_of_files(
|
||||||
|
path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only
|
||||||
|
)
|
||||||
|
configuration_files_map = {}
|
||||||
|
for file_name in all_files:
|
||||||
|
search = _re_configuration_file.search(file_name)
|
||||||
|
if search is not None:
|
||||||
|
v = search.groups()[0]
|
||||||
|
configuration_files_map[v] = file_name
|
||||||
|
available_versions = sorted(configuration_files_map.keys())
|
||||||
|
|
||||||
|
# Defaults to FULL_CONFIGURATION_FILE and then try to look at some newer versions.
|
||||||
|
configuration_file = FULL_CONFIGURATION_FILE
|
||||||
|
transformers_version = version.parse(__version__)
|
||||||
|
for v in available_versions:
|
||||||
|
if version.parse(v) <= transformers_version:
|
||||||
|
configuration_file = configuration_files_map[v]
|
||||||
|
else:
|
||||||
|
# No point going further since the versions are sorted.
|
||||||
|
break
|
||||||
|
|
||||||
|
return configuration_file
|
||||||
|
|
||||||
|
|
||||||
PretrainedConfig.push_to_hub = copy_func(PretrainedConfig.push_to_hub)
|
PretrainedConfig.push_to_hub = copy_func(PretrainedConfig.push_to_hub)
|
||||||
PretrainedConfig.push_to_hub.__doc__ = PretrainedConfig.push_to_hub.__doc__.format(
|
PretrainedConfig.push_to_hub.__doc__ = PretrainedConfig.push_to_hub.__doc__.format(
|
||||||
object="config", object_class="AutoConfig", object_files="configuration file"
|
object="config", object_class="AutoConfig", object_files="configuration file"
|
||||||
|
|||||||
@@ -16,8 +16,10 @@
|
|||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
import unittest.mock
|
||||||
|
|
||||||
from huggingface_hub import Repository, delete_repo, login
|
from huggingface_hub import Repository, delete_repo, login
|
||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
@@ -306,3 +308,40 @@ class ConfigTestUtils(unittest.TestCase):
|
|||||||
"The following keys are set with the default values in `test_configuration_common.config_common_kwargs` "
|
"The following keys are set with the default values in `test_configuration_common.config_common_kwargs` "
|
||||||
f"pick another value for them: {', '.join(keys_with_defaults)}."
|
f"pick another value for them: {', '.join(keys_with_defaults)}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigurationVersioningTest(unittest.TestCase):
|
||||||
|
def test_local_versioning(self):
|
||||||
|
configuration = AutoConfig.from_pretrained("bert-base-cased")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
configuration.save_pretrained(tmp_dir)
|
||||||
|
configuration.hidden_size = 2
|
||||||
|
json.dump(configuration.to_dict(), open(os.path.join(tmp_dir, "config.4.0.0.json"), "w"))
|
||||||
|
|
||||||
|
# This should pick the new configuration file as the version of Transformers is > 4.0.0
|
||||||
|
new_configuration = AutoConfig.from_pretrained(tmp_dir)
|
||||||
|
self.assertEqual(new_configuration.hidden_size, 2)
|
||||||
|
|
||||||
|
# Will need to be adjusted if we reach v42 and this test is still here.
|
||||||
|
# Should pick the old configuration file as the version of Transformers is < 4.42.0
|
||||||
|
shutil.move(os.path.join(tmp_dir, "config.4.0.0.json"), os.path.join(tmp_dir, "config.42.0.0.json"))
|
||||||
|
new_configuration = AutoConfig.from_pretrained(tmp_dir)
|
||||||
|
self.assertEqual(new_configuration.hidden_size, 768)
|
||||||
|
|
||||||
|
def test_repo_versioning_before(self):
|
||||||
|
# This repo has two configuration files, one for v5.0.0 and above with an added token, one for versions lower.
|
||||||
|
repo = "microsoft/layoutxlm-base"
|
||||||
|
|
||||||
|
import transformers as new_transformers
|
||||||
|
|
||||||
|
new_transformers.configuration_utils.__version__ = "v5.0.0"
|
||||||
|
new_configuration = new_transformers.models.auto.AutoConfig.from_pretrained(repo)
|
||||||
|
self.assertEqual(new_configuration.tokenizer_class, None)
|
||||||
|
|
||||||
|
# Testing an older version by monkey-patching the version in the module it's used.
|
||||||
|
import transformers as old_transformers
|
||||||
|
|
||||||
|
old_transformers.configuration_utils.__version__ = "v3.0.0"
|
||||||
|
old_configuration = old_transformers.models.auto.AutoConfig.from_pretrained(repo)
|
||||||
|
self.assertEqual(old_configuration.tokenizer_class, "XLMRobertaTokenizer")
|
||||||
|
|||||||
Reference in New Issue
Block a user