Clean up auto mapping names (#21903)
* add new test * fix after new test --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -23,6 +23,7 @@ from pathlib import Path
|
||||
|
||||
from transformers import is_flax_available, is_tf_available, is_torch_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES
|
||||
from transformers.models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING_NAMES
|
||||
from transformers.models.auto.image_processing_auto import IMAGE_PROCESSOR_MAPPING_NAMES
|
||||
from transformers.models.auto.modeling_auto import MODEL_MAPPING_NAMES
|
||||
@@ -646,6 +647,31 @@ def check_all_auto_object_names_being_defined():
|
||||
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
|
||||
|
||||
|
||||
def check_all_auto_mapping_names_in_config_mapping_names():
|
||||
"""Check all keys defined in auto mappings (mappings of names) appear in `CONFIG_MAPPING_NAMES`."""
|
||||
failures = []
|
||||
|
||||
# `TOKENIZER_PROCESSOR_MAPPING_NAMES` and `AutoTokenizer` is special, and don't need to follow the rule.
|
||||
mapping_to_check = {
|
||||
"IMAGE_PROCESSOR_MAPPING_NAMES": IMAGE_PROCESSOR_MAPPING_NAMES,
|
||||
"FEATURE_EXTRACTOR_MAPPING_NAMES": FEATURE_EXTRACTOR_MAPPING_NAMES,
|
||||
"PROCESSOR_MAPPING_NAMES": PROCESSOR_MAPPING_NAMES,
|
||||
"MODEL_MAPPING_NAMES": MODEL_MAPPING_NAMES,
|
||||
"TF_MODEL_MAPPING_NAMES": TF_MODEL_MAPPING_NAMES,
|
||||
"FLAX_MODEL_MAPPING_NAMES": FLAX_MODEL_MAPPING_NAMES,
|
||||
}
|
||||
|
||||
for name, mapping in mapping_to_check.items():
|
||||
for model_type, class_names in mapping.items():
|
||||
if model_type not in CONFIG_MAPPING_NAMES:
|
||||
failures.append(
|
||||
f"`{model_type}` appears in the mapping `{name}` but it is not defined in the keys of "
|
||||
"`CONFIG_MAPPING_NAMES`."
|
||||
)
|
||||
if len(failures) > 0:
|
||||
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
|
||||
|
||||
|
||||
_re_decorator = re.compile(r"^\s*@(\S+)\s+$")
|
||||
|
||||
|
||||
@@ -922,6 +948,8 @@ def check_repo_quality():
|
||||
check_all_models_are_auto_configured()
|
||||
print("Checking all names in auto name mappings are defined.")
|
||||
check_all_auto_object_names_being_defined()
|
||||
print("Checking all keys in auto name mappings are defined in `CONFIG_MAPPING_NAMES`.")
|
||||
check_all_auto_mapping_names_in_config_mapping_names()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user