From c817bc44e2915593a292df9e7d4e2c0dfefb6620 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Thu, 29 Jun 2023 17:49:59 +0200 Subject: [PATCH] Check all objects are equally in the main `__init__` file (#24573) * fix --------- Co-authored-by: ydshieh --- docs/source/en/model_doc/auto.md | 12 ++++ src/transformers/__init__.py | 6 ++ src/transformers/models/auto/__init__.py | 6 ++ .../models/auto/modeling_tf_auto.py | 2 +- src/transformers/utils/dummy_tf_objects.py | 21 +++++++ utils/check_repo.py | 61 ++++++++++++++++++- 6 files changed, 106 insertions(+), 2 deletions(-) diff --git a/docs/source/en/model_doc/auto.md b/docs/source/en/model_doc/auto.md index 07ff0b8d6d..f493e208ee 100644 --- a/docs/source/en/model_doc/auto.md +++ b/docs/source/en/model_doc/auto.md @@ -142,6 +142,10 @@ The following auto classes are available for the following natural language proc [[autodoc]] AutoModelForMaskGeneration +### TFAutoModelForMaskGeneration + +[[autodoc]] TFAutoModelForMaskGeneration + ### AutoModelForSeq2SeqLM [[autodoc]] AutoModelForSeq2SeqLM @@ -250,6 +254,10 @@ The following auto classes are available for the following computer vision tasks [[autodoc]] AutoModelForMaskedImageModeling +### TFAutoModelForMaskedImageModeling + +[[autodoc]] TFAutoModelForMaskedImageModeling + ### AutoModelForObjectDetection [[autodoc]] AutoModelForObjectDetection @@ -296,6 +304,10 @@ The following auto classes are available for the following audio tasks. ### AutoModelForAudioFrameClassification +[[autodoc]] TFAutoModelForAudioClassification + +### TFAutoModelForAudioFrameClassification + [[autodoc]] AutoModelForAudioFrameClassification ### AutoModelForCTC diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 3747e1951f..99683306d6 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3025,10 +3025,13 @@ else: "TF_MODEL_MAPPING", "TF_MODEL_WITH_LM_HEAD_MAPPING", "TFAutoModel", + "TFAutoModelForAudioClassification", "TFAutoModelForCausalLM", "TFAutoModelForDocumentQuestionAnswering", "TFAutoModelForImageClassification", + "TFAutoModelForMaskedImageModeling", "TFAutoModelForMaskedLM", + "TFAutoModelForMaskGeneration", "TFAutoModelForMultipleChoice", "TFAutoModelForNextSentencePrediction", "TFAutoModelForPreTraining", @@ -6453,10 +6456,13 @@ if TYPE_CHECKING: TF_MODEL_MAPPING, TF_MODEL_WITH_LM_HEAD_MAPPING, TFAutoModel, + TFAutoModelForAudioClassification, TFAutoModelForCausalLM, TFAutoModelForDocumentQuestionAnswering, TFAutoModelForImageClassification, + TFAutoModelForMaskedImageModeling, TFAutoModelForMaskedLM, + TFAutoModelForMaskGeneration, TFAutoModelForMultipleChoice, TFAutoModelForNextSentencePrediction, TFAutoModelForPreTraining, diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index 26b6609c28..36286f2be0 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -140,9 +140,12 @@ else: "TF_MODEL_MAPPING", "TF_MODEL_WITH_LM_HEAD_MAPPING", "TFAutoModel", + "TFAutoModelForAudioClassification", "TFAutoModelForCausalLM", "TFAutoModelForImageClassification", + "TFAutoModelForMaskedImageModeling", "TFAutoModelForMaskedLM", + "TFAutoModelForMaskGeneration", "TFAutoModelForMultipleChoice", "TFAutoModelForNextSentencePrediction", "TFAutoModelForPreTraining", @@ -313,10 +316,13 @@ if TYPE_CHECKING: TF_MODEL_MAPPING, TF_MODEL_WITH_LM_HEAD_MAPPING, TFAutoModel, + TFAutoModelForAudioClassification, TFAutoModelForCausalLM, TFAutoModelForDocumentQuestionAnswering, TFAutoModelForImageClassification, + TFAutoModelForMaskedImageModeling, TFAutoModelForMaskedLM, + TFAutoModelForMaskGeneration, TFAutoModelForMultipleChoice, TFAutoModelForNextSentencePrediction, TFAutoModelForPreTraining, diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index 10e33cbebb..ecf9b06da5 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -593,7 +593,7 @@ class TFAutoModelForSemanticSegmentation(_BaseAutoModelClass): _model_mapping = TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING -TF_AutoModelForSemanticSegmentation = auto_class_update( +TFAutoModelForSemanticSegmentation = auto_class_update( TFAutoModelForSemanticSegmentation, head_doc="semantic segmentation" ) diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index e556844028..46cde8ffbe 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -289,6 +289,13 @@ class TFAutoModel(metaclass=DummyObject): requires_backends(self, ["tf"]) +class TFAutoModelForAudioClassification(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + class TFAutoModelForCausalLM(metaclass=DummyObject): _backends = ["tf"] @@ -310,6 +317,13 @@ class TFAutoModelForImageClassification(metaclass=DummyObject): requires_backends(self, ["tf"]) +class TFAutoModelForMaskedImageModeling(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + class TFAutoModelForMaskedLM(metaclass=DummyObject): _backends = ["tf"] @@ -317,6 +331,13 @@ class TFAutoModelForMaskedLM(metaclass=DummyObject): requires_backends(self, ["tf"]) +class TFAutoModelForMaskGeneration(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + class TFAutoModelForMultipleChoice(metaclass=DummyObject): _backends = ["tf"] diff --git a/utils/check_repo.py b/utils/check_repo.py index 46407eb1a5..db947d834b 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -16,6 +16,7 @@ import inspect import os import re +import sys import warnings from collections import OrderedDict from difflib import get_close_matches @@ -336,6 +337,21 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [ "MusicgenForConditionalGeneration", ] +# DO NOT edit this list! +# (The corresponding pytorch objects should never be in the main `__init__`, but it's too late to remove) +OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK = [ + "FlaxBertLayer", + "FlaxBigBirdLayer", + "FlaxRoFormerLayer", + "TFBertLayer", + "TFLxmertEncoder", + "TFLxmertXLayer", + "TFMPNetLayer", + "TFMobileBertLayer", + "TFSegformerLayer", + "TFViTMAELayer", +] + # Update this list for models that have multiple model types for the same # model doc MODEL_TYPE_TO_DOC_MAPPING = OrderedDict( @@ -735,7 +751,48 @@ def check_all_auto_mappings_importable(): for name, _ in mappings_to_check.items(): name = name.replace("_MAPPING_NAMES", "_MAPPING") if not hasattr(transformers, name): - failures.append(f"`{name}` should be defined in the main `__init__` file.") + failures.append(f"`{name}`") + if len(failures) > 0: + raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures)) + + +def check_objects_being_equally_in_main_init(): + """Check if an object is in the main __init__ if its counterpart in PyTorch is.""" + attrs = dir(transformers) + + failures = [] + for attr in attrs: + obj = getattr(transformers, attr) + if hasattr(obj, "__module__"): + module_path = obj.__module__ + module_name = module_path.split(".")[-1] + module_dir = ".".join(module_path.split(".")[:-1]) + if ( + module_name.startswith("modeling_") + and not module_name.startswith("modeling_tf_") + and not module_name.startswith("modeling_flax_") + ): + parent_module = sys.modules[module_dir] + + frameworks = [] + if is_tf_available(): + frameworks.append("TF") + if is_flax_available(): + frameworks.append("Flax") + + for framework in frameworks: + other_module_path = module_path.replace("modeling_", f"modeling_{framework.lower()}_") + if os.path.isfile("src/" + other_module_path.replace(".", "/") + ".py"): + other_module_name = module_name.replace("modeling_", f"modeling_{framework.lower()}_") + other_module = getattr(parent_module, other_module_name) + if hasattr(other_module, f"{framework}{attr}"): + if not hasattr(transformers, f"{framework}{attr}"): + if f"{framework}{attr}" not in OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK: + failures.append(f"{framework}{attr}") + if hasattr(other_module, f"{framework}_{attr}"): + if not hasattr(transformers, f"{framework}_{attr}"): + if f"{framework}_{attr}" not in OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK: + failures.append(f"{framework}_{attr}") if len(failures) > 0: raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures)) @@ -1024,6 +1081,8 @@ def check_repo_quality(): check_all_auto_mapping_names_in_config_mapping_names() print("Checking all auto mappings could be imported.") check_all_auto_mappings_importable() + print("Checking all objects are equally (across frameworks) in the main __init__.") + check_objects_being_equally_in_main_init() if __name__ == "__main__":