Check all objects are equally in the main __init__ file (#24573)
* fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from difflib import get_close_matches
|
||||
@@ -336,6 +337,21 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
|
||||
"MusicgenForConditionalGeneration",
|
||||
]
|
||||
|
||||
# DO NOT edit this list!
|
||||
# (The corresponding pytorch objects should never be in the main `__init__`, but it's too late to remove)
|
||||
OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK = [
|
||||
"FlaxBertLayer",
|
||||
"FlaxBigBirdLayer",
|
||||
"FlaxRoFormerLayer",
|
||||
"TFBertLayer",
|
||||
"TFLxmertEncoder",
|
||||
"TFLxmertXLayer",
|
||||
"TFMPNetLayer",
|
||||
"TFMobileBertLayer",
|
||||
"TFSegformerLayer",
|
||||
"TFViTMAELayer",
|
||||
]
|
||||
|
||||
# Update this list for models that have multiple model types for the same
|
||||
# model doc
|
||||
MODEL_TYPE_TO_DOC_MAPPING = OrderedDict(
|
||||
@@ -735,7 +751,48 @@ def check_all_auto_mappings_importable():
|
||||
for name, _ in mappings_to_check.items():
|
||||
name = name.replace("_MAPPING_NAMES", "_MAPPING")
|
||||
if not hasattr(transformers, name):
|
||||
failures.append(f"`{name}` should be defined in the main `__init__` file.")
|
||||
failures.append(f"`{name}`")
|
||||
if len(failures) > 0:
|
||||
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
|
||||
|
||||
|
||||
def check_objects_being_equally_in_main_init():
|
||||
"""Check if an object is in the main __init__ if its counterpart in PyTorch is."""
|
||||
attrs = dir(transformers)
|
||||
|
||||
failures = []
|
||||
for attr in attrs:
|
||||
obj = getattr(transformers, attr)
|
||||
if hasattr(obj, "__module__"):
|
||||
module_path = obj.__module__
|
||||
module_name = module_path.split(".")[-1]
|
||||
module_dir = ".".join(module_path.split(".")[:-1])
|
||||
if (
|
||||
module_name.startswith("modeling_")
|
||||
and not module_name.startswith("modeling_tf_")
|
||||
and not module_name.startswith("modeling_flax_")
|
||||
):
|
||||
parent_module = sys.modules[module_dir]
|
||||
|
||||
frameworks = []
|
||||
if is_tf_available():
|
||||
frameworks.append("TF")
|
||||
if is_flax_available():
|
||||
frameworks.append("Flax")
|
||||
|
||||
for framework in frameworks:
|
||||
other_module_path = module_path.replace("modeling_", f"modeling_{framework.lower()}_")
|
||||
if os.path.isfile("src/" + other_module_path.replace(".", "/") + ".py"):
|
||||
other_module_name = module_name.replace("modeling_", f"modeling_{framework.lower()}_")
|
||||
other_module = getattr(parent_module, other_module_name)
|
||||
if hasattr(other_module, f"{framework}{attr}"):
|
||||
if not hasattr(transformers, f"{framework}{attr}"):
|
||||
if f"{framework}{attr}" not in OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK:
|
||||
failures.append(f"{framework}{attr}")
|
||||
if hasattr(other_module, f"{framework}_{attr}"):
|
||||
if not hasattr(transformers, f"{framework}_{attr}"):
|
||||
if f"{framework}_{attr}" not in OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK:
|
||||
failures.append(f"{framework}_{attr}")
|
||||
if len(failures) > 0:
|
||||
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
|
||||
|
||||
@@ -1024,6 +1081,8 @@ def check_repo_quality():
|
||||
check_all_auto_mapping_names_in_config_mapping_names()
|
||||
print("Checking all auto mappings could be imported.")
|
||||
check_all_auto_mappings_importable()
|
||||
print("Checking all objects are equally (across frameworks) in the main __init__.")
|
||||
check_objects_being_equally_in_main_init()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user