Clean up auto mapping names (#21903)

* add new test

* fix after new test

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2023-03-02 17:14:50 +01:00
committed by GitHub
parent 50a8ed3ee0
commit 99ba36e72f
3 changed files with 29 additions and 5 deletions

View File

@@ -23,6 +23,7 @@ from pathlib import Path
from transformers import is_flax_available, is_tf_available, is_torch_available
from transformers.models.auto import get_values
from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES
from transformers.models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING_NAMES
from transformers.models.auto.image_processing_auto import IMAGE_PROCESSOR_MAPPING_NAMES
from transformers.models.auto.modeling_auto import MODEL_MAPPING_NAMES
@@ -646,6 +647,31 @@ def check_all_auto_object_names_being_defined():
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
def check_all_auto_mapping_names_in_config_mapping_names():
"""Check all keys defined in auto mappings (mappings of names) appear in `CONFIG_MAPPING_NAMES`."""
failures = []
# `TOKENIZER_PROCESSOR_MAPPING_NAMES` and `AutoTokenizer` is special, and don't need to follow the rule.
mapping_to_check = {
"IMAGE_PROCESSOR_MAPPING_NAMES": IMAGE_PROCESSOR_MAPPING_NAMES,
"FEATURE_EXTRACTOR_MAPPING_NAMES": FEATURE_EXTRACTOR_MAPPING_NAMES,
"PROCESSOR_MAPPING_NAMES": PROCESSOR_MAPPING_NAMES,
"MODEL_MAPPING_NAMES": MODEL_MAPPING_NAMES,
"TF_MODEL_MAPPING_NAMES": TF_MODEL_MAPPING_NAMES,
"FLAX_MODEL_MAPPING_NAMES": FLAX_MODEL_MAPPING_NAMES,
}
for name, mapping in mapping_to_check.items():
for model_type, class_names in mapping.items():
if model_type not in CONFIG_MAPPING_NAMES:
failures.append(
f"`{model_type}` appears in the mapping `{name}` but it is not defined in the keys of "
"`CONFIG_MAPPING_NAMES`."
)
if len(failures) > 0:
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
_re_decorator = re.compile(r"^\s*@(\S+)\s+$")
@@ -922,6 +948,8 @@ def check_repo_quality():
check_all_models_are_auto_configured()
print("Checking all names in auto name mappings are defined.")
check_all_auto_object_names_being_defined()
print("Checking all keys in auto name mappings are defined in `CONFIG_MAPPING_NAMES`.")
check_all_auto_mapping_names_in_config_mapping_names()
if __name__ == "__main__":