Fix typo in PROCESSOR_MAPPING_NAMES and add tests (#21703)
* Add test * Fix GITProcessor * Update --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -23,6 +23,13 @@ from pathlib import Path
|
||||
|
||||
from transformers import is_flax_available, is_tf_available, is_torch_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING_NAMES
|
||||
from transformers.models.auto.image_processing_auto import IMAGE_PROCESSOR_MAPPING_NAMES
|
||||
from transformers.models.auto.modeling_auto import MODEL_MAPPING_NAMES
|
||||
from transformers.models.auto.modeling_flax_auto import FLAX_MODEL_MAPPING_NAMES
|
||||
from transformers.models.auto.modeling_tf_auto import TF_MODEL_MAPPING_NAMES
|
||||
from transformers.models.auto.processing_auto import PROCESSOR_MAPPING_NAMES
|
||||
from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES
|
||||
from transformers.utils import ENV_VARS_TRUE_VALUES, direct_transformers_import
|
||||
|
||||
|
||||
@@ -602,6 +609,40 @@ def check_all_models_are_auto_configured():
|
||||
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
|
||||
|
||||
|
||||
def check_all_auto_object_names_being_defined():
|
||||
"""Check all names defined in auto (name) mappings exist in the library."""
|
||||
failures = []
|
||||
|
||||
mapping_to_check = {
|
||||
"TOKENIZER_MAPPING_NAMES": TOKENIZER_MAPPING_NAMES,
|
||||
"IMAGE_PROCESSOR_MAPPING_NAMES": IMAGE_PROCESSOR_MAPPING_NAMES,
|
||||
"FEATURE_EXTRACTOR_MAPPING_NAMES": FEATURE_EXTRACTOR_MAPPING_NAMES,
|
||||
"PROCESSOR_MAPPING_NAMES": PROCESSOR_MAPPING_NAMES,
|
||||
"MODEL_MAPPING_NAMES": MODEL_MAPPING_NAMES,
|
||||
"TF_MODEL_MAPPING_NAMES": TF_MODEL_MAPPING_NAMES,
|
||||
"FLAX_MODEL_MAPPING_NAMES": FLAX_MODEL_MAPPING_NAMES,
|
||||
}
|
||||
|
||||
for name, mapping in mapping_to_check.items():
|
||||
for model_type, class_names in mapping.items():
|
||||
if not isinstance(class_names, tuple):
|
||||
class_names = (class_names,)
|
||||
for class_name in class_names:
|
||||
if class_name is None:
|
||||
continue
|
||||
# dummy object is accepted
|
||||
if not hasattr(transformers, class_name):
|
||||
# If the class name is in a model name mapping, let's not check if there is a definition in any modeling
|
||||
# module, if it's a private model defined in this file.
|
||||
if name.endswith("MODEL_MAPPING_NAMES") and is_a_private_model(class_name):
|
||||
continue
|
||||
failures.append(
|
||||
f"`{class_name}` appears in the mapping `{name}` but it is not defined in the library."
|
||||
)
|
||||
if len(failures) > 0:
|
||||
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
|
||||
|
||||
|
||||
_re_decorator = re.compile(r"^\s*@(\S+)\s+$")
|
||||
|
||||
|
||||
@@ -876,6 +917,8 @@ def check_repo_quality():
|
||||
check_all_objects_are_documented()
|
||||
print("Checking all models are in at least one auto class.")
|
||||
check_all_models_are_auto_configured()
|
||||
print("Checking all names in auto name mappings are defined.")
|
||||
check_all_auto_object_names_being_defined()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user