Allow per-version configurations (#14344)
* Allow per-version configurations * Update tests/test_configuration_common.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update tests/test_configuration_common.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -16,8 +16,10 @@
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
import unittest.mock
|
||||
|
||||
from huggingface_hub import Repository, delete_repo, login
|
||||
from requests.exceptions import HTTPError
|
||||
@@ -306,3 +308,40 @@ class ConfigTestUtils(unittest.TestCase):
|
||||
"The following keys are set with the default values in `test_configuration_common.config_common_kwargs` "
|
||||
f"pick another value for them: {', '.join(keys_with_defaults)}."
|
||||
)
|
||||
|
||||
|
||||
class ConfigurationVersioningTest(unittest.TestCase):
|
||||
def test_local_versioning(self):
|
||||
configuration = AutoConfig.from_pretrained("bert-base-cased")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
configuration.save_pretrained(tmp_dir)
|
||||
configuration.hidden_size = 2
|
||||
json.dump(configuration.to_dict(), open(os.path.join(tmp_dir, "config.4.0.0.json"), "w"))
|
||||
|
||||
# This should pick the new configuration file as the version of Transformers is > 4.0.0
|
||||
new_configuration = AutoConfig.from_pretrained(tmp_dir)
|
||||
self.assertEqual(new_configuration.hidden_size, 2)
|
||||
|
||||
# Will need to be adjusted if we reach v42 and this test is still here.
|
||||
# Should pick the old configuration file as the version of Transformers is < 4.42.0
|
||||
shutil.move(os.path.join(tmp_dir, "config.4.0.0.json"), os.path.join(tmp_dir, "config.42.0.0.json"))
|
||||
new_configuration = AutoConfig.from_pretrained(tmp_dir)
|
||||
self.assertEqual(new_configuration.hidden_size, 768)
|
||||
|
||||
def test_repo_versioning_before(self):
|
||||
# This repo has two configuration files, one for v5.0.0 and above with an added token, one for versions lower.
|
||||
repo = "microsoft/layoutxlm-base"
|
||||
|
||||
import transformers as new_transformers
|
||||
|
||||
new_transformers.configuration_utils.__version__ = "v5.0.0"
|
||||
new_configuration = new_transformers.models.auto.AutoConfig.from_pretrained(repo)
|
||||
self.assertEqual(new_configuration.tokenizer_class, None)
|
||||
|
||||
# Testing an older version by monkey-patching the version in the module it's used.
|
||||
import transformers as old_transformers
|
||||
|
||||
old_transformers.configuration_utils.__version__ = "v3.0.0"
|
||||
old_configuration = old_transformers.models.auto.AutoConfig.from_pretrained(repo)
|
||||
self.assertEqual(old_configuration.tokenizer_class, "XLMRobertaTokenizer")
|
||||
|
||||
Reference in New Issue
Block a user