Fix _configuration_file argument getting passed to model (#15629)
This commit is contained in:
@@ -580,7 +580,7 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||||
config_file = pretrained_model_name_or_path
|
config_file = pretrained_model_name_or_path
|
||||||
else:
|
else:
|
||||||
configuration_file = kwargs.get("_configuration_file", CONFIG_NAME)
|
configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)
|
||||||
|
|
||||||
if os.path.isdir(pretrained_model_name_or_path):
|
if os.path.isdir(pretrained_model_name_or_path):
|
||||||
config_file = os.path.join(pretrained_model_name_or_path, configuration_file)
|
config_file = os.path.join(pretrained_model_name_or_path, configuration_file)
|
||||||
|
|||||||
@@ -334,8 +334,12 @@ class ConfigurationVersioningTest(unittest.TestCase):
|
|||||||
import transformers as new_transformers
|
import transformers as new_transformers
|
||||||
|
|
||||||
new_transformers.configuration_utils.__version__ = "v4.0.0"
|
new_transformers.configuration_utils.__version__ = "v4.0.0"
|
||||||
new_configuration = new_transformers.models.auto.AutoConfig.from_pretrained(repo)
|
new_configuration, kwargs = new_transformers.models.auto.AutoConfig.from_pretrained(
|
||||||
|
repo, return_unused_kwargs=True
|
||||||
|
)
|
||||||
self.assertEqual(new_configuration.hidden_size, 2)
|
self.assertEqual(new_configuration.hidden_size, 2)
|
||||||
|
# This checks `_configuration_file` ia not kept in the kwargs by mistake.
|
||||||
|
self.assertDictEqual(kwargs, {"_from_auto": True})
|
||||||
|
|
||||||
# Testing an older version by monkey-patching the version in the module it's used.
|
# Testing an older version by monkey-patching the version in the module it's used.
|
||||||
import transformers as old_transformers
|
import transformers as old_transformers
|
||||||
|
|||||||
Reference in New Issue
Block a user