Add all XxxPreTrainedModel to the main init (#12314)

* Add all XxxPreTrainedModel to the main init

* Add to template

* Add to template bis

* Add FlaxT5
This commit is contained in:
Sylvain Gugger
2021-06-23 10:40:54 -04:00
committed by GitHub
parent 53c60babe4
commit 9eda6b52e2
26 changed files with 532 additions and 51 deletions

View File

@@ -31,9 +31,16 @@ PATH_TO_TRANSFORMERS = "src/transformers"
PATH_TO_TESTS = "tests"
PATH_TO_DOC = "docs/source"
# Update this list with models that are supposed to be private.
PRIVATE_MODELS = [
"DPRSpanPredictor",
"T5Stack",
"TFDPRSpanPredictor",
]
# Update this list for models that are not tested with a comment explaining the reason it should not be.
# Being in this list is an exception and should **not** be the rule.
IGNORE_NON_TESTED = [
IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
# models to ignore for not tested
"BigBirdPegasusEncoder", # Building part of bigger (tested) model.
"BigBirdPegasusDecoder", # Building part of bigger (tested) model.
@@ -63,12 +70,9 @@ IGNORE_NON_TESTED = [
"PegasusEncoder", # Building part of bigger (tested) model.
"PegasusDecoderWrapper", # Building part of bigger (tested) model.
"DPREncoder", # Building part of bigger (tested) model.
"DPRSpanPredictor", # Building part of bigger (tested) model.
"ProphetNetDecoderWrapper", # Building part of bigger (tested) model.
"ReformerForMaskedLM", # Needs to be setup as decoder.
"T5Stack", # Building part of bigger (tested) model.
"TFDPREncoder", # Building part of bigger (tested) model.
"TFDPRSpanPredictor", # Building part of bigger (tested) model.
"TFElectraMainLayer", # Building part of bigger (tested) model (should it be a TFPreTrainedModel ?)
"TFRobertaForMultipleChoice", # TODO: fix
"SeparableConv1D", # Building part of bigger (tested) model.
@@ -92,7 +96,7 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
# Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and
# should **not** be the rule.
IGNORE_NON_AUTO_CONFIGURED = [
IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
# models to ignore for model xxx mapping
"CLIPTextModel",
"CLIPVisionModel",
@@ -100,7 +104,6 @@ IGNORE_NON_AUTO_CONFIGURED = [
"FlaxCLIPVisionModel",
"DetrForSegmentation",
"DPRReader",
"DPRSpanPredictor",
"FlaubertForQuestionAnswering",
"GPT2DoubleHeadsModel",
"LukeForEntityClassification",
@@ -110,9 +113,7 @@ IGNORE_NON_AUTO_CONFIGURED = [
"RagModel",
"RagSequenceForGeneration",
"RagTokenForGeneration",
"T5Stack",
"TFDPRReader",
"TFDPRSpanPredictor",
"TFGPT2DoubleHeadsModel",
"TFOpenAIGPTDoubleHeadsModel",
"TFRagModel",
@@ -173,12 +174,12 @@ def get_model_modules():
return modules
def get_models(module):
def get_models(module, include_pretrained=False):
"""Get the objects in module that are models."""
models = []
model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel)
for attr_name in dir(module):
if "Pretrained" in attr_name or "PreTrained" in attr_name:
if not include_pretrained and ("Pretrained" in attr_name or "PreTrained" in attr_name):
continue
attr = getattr(module, attr_name)
if isinstance(attr, type) and issubclass(attr, model_classes) and attr.__module__ == module.__name__:
@@ -186,6 +187,36 @@ def get_models(module):
return models
def is_a_private_model(model):
"""Returns True if the model should not be in the main init."""
if model in PRIVATE_MODELS:
return True
# Wrapper, Encoder and Decoder are all privates
if model.endswith("Wrapper"):
return True
if model.endswith("Encoder"):
return True
if model.endswith("Decoder"):
return True
return False
def check_models_are_in_init():
"""Checks all models defined in the library are in the main init."""
models_not_in_init = []
dir_transformers = dir(transformers)
for module in get_model_modules():
models_not_in_init += [
model[0] for model in get_models(module, include_pretrained=True) if model[0] not in dir_transformers
]
# Remove private models
models_not_in_init = [model for model in models_not_in_init if not is_a_private_model(model)]
if len(models_not_in_init) > 0:
raise Exception(f"The following models should be in the main init: {','.join(models_not_in_init)}.")
# If some test_modeling files should be ignored when checking models are all tested, they should be added in the
# nested list _ignore_files of this function.
def get_model_test_files():
@@ -229,6 +260,7 @@ def find_tested_models(test_file):
def check_models_are_tested(module, test_file):
"""Check models defined in module are tested in test_file."""
# XxxPreTrainedModel are not tested
defined_models = get_models(module)
tested_models = find_tested_models(test_file)
if tested_models is None:
@@ -515,6 +547,8 @@ def check_all_objects_are_documented():
def check_repo_quality():
"""Check all models are properly tested and documented."""
print("Checking all models are public.")
check_models_are_in_init()
print("Checking all models are properly tested.")
check_all_decorator_order()
check_all_models_are_tested()