From ed6ceb7649c548998d884cb8166fc6179f791584 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Tue, 21 Feb 2023 09:38:26 +0100 Subject: [PATCH] Fix typo in `PROCESSOR_MAPPING_NAMES` and add tests (#21703) * Add test * Fix GITProcessor * Update --------- Co-authored-by: ydshieh --- .../models/auto/processing_auto.py | 2 +- utils/check_repo.py | 43 +++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index bac6fe78c1..5405df3f7f 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -50,7 +50,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict( ("clip", "CLIPProcessor"), ("clipseg", "CLIPSegProcessor"), ("flava", "FlavaProcessor"), - ("git", "GITProcessor"), + ("git", "GitProcessor"), ("groupvit", "CLIPProcessor"), ("hubert", "Wav2Vec2Processor"), ("layoutlmv2", "LayoutLMv2Processor"), diff --git a/utils/check_repo.py b/utils/check_repo.py index c061b1fdc1..53717645cf 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -23,6 +23,13 @@ from pathlib import Path from transformers import is_flax_available, is_tf_available, is_torch_available from transformers.models.auto import get_values +from transformers.models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING_NAMES +from transformers.models.auto.image_processing_auto import IMAGE_PROCESSOR_MAPPING_NAMES +from transformers.models.auto.modeling_auto import MODEL_MAPPING_NAMES +from transformers.models.auto.modeling_flax_auto import FLAX_MODEL_MAPPING_NAMES +from transformers.models.auto.modeling_tf_auto import TF_MODEL_MAPPING_NAMES +from transformers.models.auto.processing_auto import PROCESSOR_MAPPING_NAMES +from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES from transformers.utils import ENV_VARS_TRUE_VALUES, direct_transformers_import @@ -602,6 
def check_all_auto_object_names_being_defined():
    """Verify that every class name listed in the auto (name) mappings exists in the library.

    Each mapping value is either a single class name or a tuple of class names; `None` entries are skipped.
    A name is accepted when it is an attribute of `transformers`. For mappings whose name ends with
    `MODEL_MAPPING_NAMES`, a missing name is also accepted when `is_a_private_model` returns a true value for it.

    Raises:
        Exception: if at least one name is not defined. The message lists every failure, one per line.
    """
    mapping_to_check = {
        "TOKENIZER_MAPPING_NAMES": TOKENIZER_MAPPING_NAMES,
        "IMAGE_PROCESSOR_MAPPING_NAMES": IMAGE_PROCESSOR_MAPPING_NAMES,
        "FEATURE_EXTRACTOR_MAPPING_NAMES": FEATURE_EXTRACTOR_MAPPING_NAMES,
        "PROCESSOR_MAPPING_NAMES": PROCESSOR_MAPPING_NAMES,
        "MODEL_MAPPING_NAMES": MODEL_MAPPING_NAMES,
        "TF_MODEL_MAPPING_NAMES": TF_MODEL_MAPPING_NAMES,
        "FLAX_MODEL_MAPPING_NAMES": FLAX_MODEL_MAPPING_NAMES,
    }

    failures = []
    for name, mapping in mapping_to_check.items():
        # Only the model mappings get the private-model exemption; decide this once per mapping.
        is_model_mapping = name.endswith("MODEL_MAPPING_NAMES")
        for entry in mapping.values():
            # Normalize a lone class name to a 1-tuple so both value shapes share the same loop.
            candidates = entry if isinstance(entry, tuple) else (entry,)
            for class_name in candidates:
                # `None` placeholders are ignored; a dummy object is accepted as a definition.
                if class_name is None or hasattr(transformers, class_name):
                    continue
                # If the class name is in a model name mapping, don't require a definition in a modeling
                # module when it is a private model.
                if is_model_mapping and is_a_private_model(class_name):
                    continue
                failures.append(
                    f"`{class_name}` appears in the mapping `{name}` but it is not defined in the library."
                )

    if failures:
        raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))