Fix check_config_attributes: check all configuration classes (#24231)

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2023-06-15 11:39:20 +02:00
committed by GitHub
parent 6793f0cfe0
commit 7504be35ab
8 changed files with 26 additions and 75 deletions

View File

@@ -17,6 +17,7 @@ import inspect
import os
import re
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import direct_transformers_import
@@ -77,6 +78,12 @@ SPECIAL_CASES_TO_ALLOW = {
"TimeSeriesTransformerConfig": ["num_static_real_features", "num_time_features"],
# used internally to calculate the feature size
"AutoformerConfig": ["num_static_real_features", "num_time_features"],
# used internally to calculate `mlp_dim`
"SamVisionConfig": ["mlp_ratio"],
# For (head) training, but so far not implemented
"ClapAudioConfig": ["num_classes"],
# Not used, but providing useful information to users
"SpeechT5HifiGanConfig": ["sampling_rate"],
}
@@ -113,6 +120,10 @@ SPECIAL_CASES_TO_ALLOW.update(
"VanConfig": True,
"WavLMConfig": True,
"WhisperConfig": True,
# TODO: @Arthur (for `alignment_head` and `alignment_layer`)
"JukeboxPriorConfig": True,
# TODO: @Younes (for `is_decoder`)
"Pix2StructTextConfig": True,
}
)
@@ -254,10 +265,21 @@ def check_config_attributes_being_used(config_class):
def check_config_attributes():
"""Check the arguments in `__init__` of all configuration classes are used in python files"""
configs_with_unused_attributes = {}
for config_class in list(CONFIG_MAPPING.values()):
unused_attributes = check_config_attributes_being_used(config_class)
if len(unused_attributes) > 0:
configs_with_unused_attributes[config_class.__name__] = unused_attributes
for _config_class in list(CONFIG_MAPPING.values()):
# Some config classes are not in `CONFIG_MAPPING` (e.g. `CLIPVisionConfig`, `Blip2VisionConfig`, etc.)
config_classes_in_module = [
cls
for name, cls in inspect.getmembers(
inspect.getmodule(_config_class),
lambda x: inspect.isclass(x)
and issubclass(x, PretrainedConfig)
and inspect.getmodule(x) == inspect.getmodule(_config_class),
)
]
for config_class in config_classes_in_module:
unused_attributes = check_config_attributes_being_used(config_class)
if len(unused_attributes) > 0:
configs_with_unused_attributes[config_class.__name__] = unused_attributes
if len(configs_with_unused_attributes) > 0:
error = "The following configuration classes contain unused attributes in the corresponding modeling files:\n"