Fix check repo utils (#8600)
This commit is contained in:
@@ -49,6 +49,7 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
|
||||
"test_modeling_mt5.py",
|
||||
"test_modeling_pegasus.py",
|
||||
"test_modeling_tf_camembert.py",
|
||||
"test_modeling_tf_mt5.py",
|
||||
"test_modeling_tf_xlm_roberta.py",
|
||||
"test_modeling_xlm_prophetnet.py",
|
||||
"test_modeling_xlm_roberta.py",
|
||||
@@ -62,7 +63,6 @@ IGNORE_NON_DOCUMENTED = [
|
||||
"T5Stack", # Building part of bigger (tested) model.
|
||||
"TFDPREncoder", # Building part of bigger (documented) model.
|
||||
"TFDPRSpanPredictor", # Building part of bigger (documented) model.
|
||||
"TFElectraMainLayer", # Building part of bigger (documented) model (should it be a TFPreTrainedModel ?)
|
||||
]
|
||||
|
||||
# Update this dict with any special correspondance model name (used in modeling_xxx.py) to doc file.
|
||||
@@ -135,11 +135,15 @@ def get_model_modules():
|
||||
"modeling_tf_transfo_xl_utilities",
|
||||
]
|
||||
modules = []
|
||||
for attr_name in dir(transformers):
|
||||
if attr_name.startswith("modeling") and attr_name not in _ignore_modules:
|
||||
module = getattr(transformers, attr_name)
|
||||
if inspect.ismodule(module):
|
||||
modules.append(module)
|
||||
for model in dir(transformers.models):
|
||||
# There are some magic dunder attributes in the dir, we ignore them
|
||||
if not model.startswith("__"):
|
||||
model_module = getattr(transformers.models, model)
|
||||
for submodule in dir(model_module):
|
||||
if submodule.startswith("modeling") and submodule not in _ignore_modules:
|
||||
modeling_module = getattr(model_module, submodule)
|
||||
if inspect.ismodule(modeling_module):
|
||||
modules.append(modeling_module)
|
||||
return modules
|
||||
|
||||
|
||||
@@ -244,7 +248,7 @@ def check_all_models_are_tested():
|
||||
test_files = get_model_test_files()
|
||||
failures = []
|
||||
for module in modules:
|
||||
test_file = f"test_{module.__name__.split('.')[1]}.py"
|
||||
test_file = f"test_{module.__name__.split('.')[-1]}.py"
|
||||
if test_file not in test_files:
|
||||
failures.append(f"{module.__name__} does not have its corresponding test file {test_file}.")
|
||||
new_failures = check_models_are_tested(module, test_file)
|
||||
@@ -279,9 +283,9 @@ def check_models_are_documented(module, doc_file):
|
||||
|
||||
def _get_model_name(module):
|
||||
""" Get the model name for the module defining it."""
|
||||
splits = module.__name__.split("_")
|
||||
module_name = module.__name__.split(".")[-1]
|
||||
splits = module_name.split("_")
|
||||
splits = splits[(2 if splits[1] in ["flax", "tf"] else 1) :]
|
||||
|
||||
return "_".join(splits)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user