Check all objects are equally in the main __init__ file (#24573)

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2023-06-29 17:49:59 +02:00
committed by GitHub
parent 8c4471d1fc
commit c817bc44e2
6 changed files with 106 additions and 2 deletions

View File

@@ -16,6 +16,7 @@
import inspect
import os
import re
import sys
import warnings
from collections import OrderedDict
from difflib import get_close_matches
@@ -336,6 +337,21 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"MusicgenForConditionalGeneration",
]
# DO NOT edit this list!
# (The corresponding pytorch objects should never be in the main `__init__`, but it's too late to remove)
OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK = [
"FlaxBertLayer",
"FlaxBigBirdLayer",
"FlaxRoFormerLayer",
"TFBertLayer",
"TFLxmertEncoder",
"TFLxmertXLayer",
"TFMPNetLayer",
"TFMobileBertLayer",
"TFSegformerLayer",
"TFViTMAELayer",
]
# Update this list for models that have multiple model types for the same
# model doc
MODEL_TYPE_TO_DOC_MAPPING = OrderedDict(
@@ -735,7 +751,48 @@ def check_all_auto_mappings_importable():
for name, _ in mappings_to_check.items():
name = name.replace("_MAPPING_NAMES", "_MAPPING")
if not hasattr(transformers, name):
failures.append(f"`{name}` should be defined in the main `__init__` file.")
failures.append(f"`{name}`")
if len(failures) > 0:
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
def check_objects_being_equally_in_main_init():
"""Check if an object is in the main __init__ if its counterpart in PyTorch is."""
attrs = dir(transformers)
failures = []
for attr in attrs:
obj = getattr(transformers, attr)
if hasattr(obj, "__module__"):
module_path = obj.__module__
module_name = module_path.split(".")[-1]
module_dir = ".".join(module_path.split(".")[:-1])
if (
module_name.startswith("modeling_")
and not module_name.startswith("modeling_tf_")
and not module_name.startswith("modeling_flax_")
):
parent_module = sys.modules[module_dir]
frameworks = []
if is_tf_available():
frameworks.append("TF")
if is_flax_available():
frameworks.append("Flax")
for framework in frameworks:
other_module_path = module_path.replace("modeling_", f"modeling_{framework.lower()}_")
if os.path.isfile("src/" + other_module_path.replace(".", "/") + ".py"):
other_module_name = module_name.replace("modeling_", f"modeling_{framework.lower()}_")
other_module = getattr(parent_module, other_module_name)
if hasattr(other_module, f"{framework}{attr}"):
if not hasattr(transformers, f"{framework}{attr}"):
if f"{framework}{attr}" not in OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK:
failures.append(f"{framework}{attr}")
if hasattr(other_module, f"{framework}_{attr}"):
if not hasattr(transformers, f"{framework}_{attr}"):
if f"{framework}_{attr}" not in OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK:
failures.append(f"{framework}_{attr}")
if len(failures) > 0:
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
@@ -1024,6 +1081,8 @@ def check_repo_quality():
check_all_auto_mapping_names_in_config_mapping_names()
print("Checking all auto mappings could be imported.")
check_all_auto_mappings_importable()
print("Checking all objects are equally (across frameworks) in the main __init__.")
check_objects_being_equally_in_main_init()
if __name__ == "__main__":