Check all models are in an auto class (#8425)

This commit is contained in:
Sylvain Gugger
2020-11-09 15:44:54 -05:00
committed by GitHub
parent ef032ddd1e
commit a39218b75b
2 changed files with 73 additions and 0 deletions

View File

@@ -70,6 +70,34 @@ MODEL_NAME_TO_DOC_FILE = {
"marian": "marian.rst",
}
# Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and
# should **not** be the rule.
IGNORE_NON_AUTO_CONFIGURED = [
"DPRContextEncoder",
"DPREncoder",
"DPRReader",
"DPRSpanPredictor",
"FlaubertForQuestionAnswering",
"FunnelBaseModel",
"GPT2DoubleHeadsModel",
"OpenAIGPTDoubleHeadsModel",
"ProphetNetDecoder",
"ProphetNetEncoder",
"RagModel",
"RagSequenceForGeneration",
"RagTokenForGeneration",
"T5Stack",
"TFBertForNextSentencePrediction",
"TFFunnelBaseModel",
"TFGPT2DoubleHeadsModel",
"TFMobileBertForNextSentencePrediction",
"TFOpenAIGPTDoubleHeadsModel",
"XLMForQuestionAnswering",
"XLMProphetNetDecoder",
"XLMProphetNetEncoder",
"XLNetForQuestionAnswering",
]
# This is to make sure the transformers module imported is the one in the repo.
spec = importlib.util.spec_from_file_location(
"transformers",
@@ -282,6 +310,45 @@ def check_all_models_are_documented():
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
def get_all_auto_configured_models():
""" Return the list of all models in at least one auto class."""
result = set() # To avoid duplicates we concatenate all model classes in a set.
for attr_name in dir(transformers.modeling_auto):
if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING"):
result = result | set(getattr(transformers.modeling_auto, attr_name).values())
for attr_name in dir(transformers.modeling_tf_auto):
if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING"):
result = result | set(getattr(transformers.modeling_tf_auto, attr_name).values())
return [cls.__name__ for cls in result]
def check_models_are_auto_configured(module, all_auto_models):
""" Check models defined in module are each in an auto class."""
defined_models = get_models(module)
failures = []
for model_name, _ in defined_models:
if model_name not in all_auto_models and model_name not in IGNORE_NON_AUTO_CONFIGURED:
failures.append(
f"{model_name} is defined in {module.__name__} but is not present in any of the auto mapping. "
"If that is intended behavior, add its name to `IGNORE_NON_AUTO_CONFIGURED` in the file "
"`utils/check_repo.py`."
)
return failures
def check_all_models_are_auto_configured():
""" Check all models are each in an auto class."""
modules = get_model_modules()
all_auto_models = get_all_auto_configured_models()
failures = []
for module in modules:
new_failures = check_models_are_auto_configured(module, all_auto_models)
if new_failures is not None:
failures += new_failures
if len(failures) > 0:
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
_re_decorator = re.compile(r"^\s*@(\S+)\s+$")
@@ -325,6 +392,8 @@ def check_repo_quality():
check_all_models_are_tested()
print("Checking all models are properly documented.")
check_all_models_are_documented()
print("Checking all models are in at least one auto class.")
check_all_models_are_auto_configured()
if __name__ == "__main__":