Doc checks (#25408)
* Document check_dummies * Type hints and doc in other files * Document check inits * Add documentation to * Address review comments
This commit is contained in:
@@ -12,15 +12,34 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Utility that performs several consistency checks on the repo. This includes:
|
||||
- checking all models are properly defined in the __init__ of models/
|
||||
- checking all models are in the main __init__
|
||||
- checking all models are properly tested
|
||||
- checking all object in the main __init__ are documented
|
||||
- checking all models are in at least one auto class
|
||||
- checking all the auto mapping are properly defined (no typos, importable)
|
||||
- checking the list of deprecated models is up to date
|
||||
|
||||
Use from the root of the repo with (as used in `make repo-consistency`):
|
||||
|
||||
```bash
|
||||
python utils/check_repo.py
|
||||
```
|
||||
|
||||
It has no auto-fix mode.
|
||||
"""
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import types
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from difflib import get_close_matches
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
from transformers import is_flax_available, is_tf_available, is_torch_available
|
||||
from transformers.models.auto import get_values
|
||||
@@ -60,91 +79,25 @@ PRIVATE_MODELS = [
|
||||
IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
|
||||
# models to ignore for not tested
|
||||
"InstructBlipQFormerModel", # Building part of bigger (tested) model.
|
||||
"NllbMoeDecoder",
|
||||
"NllbMoeEncoder",
|
||||
"UMT5EncoderModel", # Building part of bigger (tested) model.
|
||||
"LlamaDecoder", # Building part of bigger (tested) model.
|
||||
"Blip2QFormerModel", # Building part of bigger (tested) model.
|
||||
"DetaEncoder", # Building part of bigger (tested) model.
|
||||
"DetaDecoder", # Building part of bigger (tested) model.
|
||||
"ErnieMForInformationExtraction",
|
||||
"GraphormerEncoder", # Building part of bigger (tested) model.
|
||||
"GraphormerDecoderHead", # Building part of bigger (tested) model.
|
||||
"CLIPSegDecoder", # Building part of bigger (tested) model.
|
||||
"TableTransformerEncoder", # Building part of bigger (tested) model.
|
||||
"TableTransformerDecoder", # Building part of bigger (tested) model.
|
||||
"TimeSeriesTransformerEncoder", # Building part of bigger (tested) model.
|
||||
"TimeSeriesTransformerDecoder", # Building part of bigger (tested) model.
|
||||
"InformerEncoder", # Building part of bigger (tested) model.
|
||||
"InformerDecoder", # Building part of bigger (tested) model.
|
||||
"AutoformerEncoder", # Building part of bigger (tested) model.
|
||||
"AutoformerDecoder", # Building part of bigger (tested) model.
|
||||
"JukeboxVQVAE", # Building part of bigger (tested) model.
|
||||
"JukeboxPrior", # Building part of bigger (tested) model.
|
||||
"DeformableDetrEncoder", # Building part of bigger (tested) model.
|
||||
"DeformableDetrDecoder", # Building part of bigger (tested) model.
|
||||
"OPTDecoder", # Building part of bigger (tested) model.
|
||||
"FlaxWhisperDecoder", # Building part of bigger (tested) model.
|
||||
"FlaxWhisperEncoder", # Building part of bigger (tested) model.
|
||||
"WhisperDecoder", # Building part of bigger (tested) model.
|
||||
"WhisperEncoder", # Building part of bigger (tested) model.
|
||||
"DecisionTransformerGPT2Model", # Building part of bigger (tested) model.
|
||||
"SegformerDecodeHead", # Building part of bigger (tested) model.
|
||||
"PLBartEncoder", # Building part of bigger (tested) model.
|
||||
"PLBartDecoder", # Building part of bigger (tested) model.
|
||||
"PLBartDecoderWrapper", # Building part of bigger (tested) model.
|
||||
"BigBirdPegasusEncoder", # Building part of bigger (tested) model.
|
||||
"BigBirdPegasusDecoder", # Building part of bigger (tested) model.
|
||||
"BigBirdPegasusDecoderWrapper", # Building part of bigger (tested) model.
|
||||
"DetrEncoder", # Building part of bigger (tested) model.
|
||||
"DetrDecoder", # Building part of bigger (tested) model.
|
||||
"DetrDecoderWrapper", # Building part of bigger (tested) model.
|
||||
"ConditionalDetrEncoder", # Building part of bigger (tested) model.
|
||||
"ConditionalDetrDecoder", # Building part of bigger (tested) model.
|
||||
"M2M100Encoder", # Building part of bigger (tested) model.
|
||||
"M2M100Decoder", # Building part of bigger (tested) model.
|
||||
"MCTCTEncoder", # Building part of bigger (tested) model.
|
||||
"MgpstrModel", # Building part of bigger (tested) model.
|
||||
"Speech2TextEncoder", # Building part of bigger (tested) model.
|
||||
"Speech2TextDecoder", # Building part of bigger (tested) model.
|
||||
"LEDEncoder", # Building part of bigger (tested) model.
|
||||
"LEDDecoder", # Building part of bigger (tested) model.
|
||||
"BartDecoderWrapper", # Building part of bigger (tested) model.
|
||||
"BartEncoder", # Building part of bigger (tested) model.
|
||||
"BertLMHeadModel", # Needs to be setup as decoder.
|
||||
"BlenderbotSmallEncoder", # Building part of bigger (tested) model.
|
||||
"BlenderbotSmallDecoderWrapper", # Building part of bigger (tested) model.
|
||||
"BlenderbotEncoder", # Building part of bigger (tested) model.
|
||||
"BlenderbotDecoderWrapper", # Building part of bigger (tested) model.
|
||||
"MBartEncoder", # Building part of bigger (tested) model.
|
||||
"MBartDecoderWrapper", # Building part of bigger (tested) model.
|
||||
"MegatronBertLMHeadModel", # Building part of bigger (tested) model.
|
||||
"MegatronBertEncoder", # Building part of bigger (tested) model.
|
||||
"MegatronBertDecoder", # Building part of bigger (tested) model.
|
||||
"MegatronBertDecoderWrapper", # Building part of bigger (tested) model.
|
||||
"MusicgenDecoder", # Building part of bigger (tested) model.
|
||||
"MvpDecoderWrapper", # Building part of bigger (tested) model.
|
||||
"MvpEncoder", # Building part of bigger (tested) model.
|
||||
"PegasusEncoder", # Building part of bigger (tested) model.
|
||||
"PegasusDecoderWrapper", # Building part of bigger (tested) model.
|
||||
"PegasusXEncoder", # Building part of bigger (tested) model.
|
||||
"PegasusXDecoder", # Building part of bigger (tested) model.
|
||||
"PegasusXDecoderWrapper", # Building part of bigger (tested) model.
|
||||
"DPREncoder", # Building part of bigger (tested) model.
|
||||
"ProphetNetDecoderWrapper", # Building part of bigger (tested) model.
|
||||
"RealmBertModel", # Building part of bigger (tested) model.
|
||||
"RealmReader", # Not regular model.
|
||||
"RealmScorer", # Not regular model.
|
||||
"RealmForOpenQA", # Not regular model.
|
||||
"ReformerForMaskedLM", # Needs to be setup as decoder.
|
||||
"Speech2Text2DecoderWrapper", # Building part of bigger (tested) model.
|
||||
"TFDPREncoder", # Building part of bigger (tested) model.
|
||||
"TFElectraMainLayer", # Building part of bigger (tested) model (should it be a TFPreTrainedModel ?)
|
||||
"TFRobertaForMultipleChoice", # TODO: fix
|
||||
"TFRobertaPreLayerNormForMultipleChoice", # TODO: fix
|
||||
"TrOCRDecoderWrapper", # Building part of bigger (tested) model.
|
||||
"TFWhisperEncoder", # Building part of bigger (tested) model.
|
||||
"TFWhisperDecoder", # Building part of bigger (tested) model.
|
||||
"SeparableConv1D", # Building part of bigger (tested) model.
|
||||
"FlaxBartForCausalLM", # Building part of bigger (tested) model.
|
||||
"FlaxBertForCausalLM", # Building part of bigger (tested) model. Tested implicitly through FlaxRobertaForCausalLM.
|
||||
@@ -155,18 +108,6 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
|
||||
"TFBlipTextLMHeadModel", # No need to test it as it is tested by BlipTextVision models
|
||||
"BridgeTowerTextModel", # No need to test it as it is tested by BridgeTowerModel model.
|
||||
"BridgeTowerVisionModel", # No need to test it as it is tested by BridgeTowerModel model.
|
||||
"SpeechT5Decoder", # Building part of bigger (tested) model.
|
||||
"SpeechT5DecoderWithoutPrenet", # Building part of bigger (tested) model.
|
||||
"SpeechT5DecoderWithSpeechPrenet", # Building part of bigger (tested) model.
|
||||
"SpeechT5DecoderWithTextPrenet", # Building part of bigger (tested) model.
|
||||
"SpeechT5Encoder", # Building part of bigger (tested) model.
|
||||
"SpeechT5EncoderWithoutPrenet", # Building part of bigger (tested) model.
|
||||
"SpeechT5EncoderWithSpeechPrenet", # Building part of bigger (tested) model.
|
||||
"SpeechT5EncoderWithTextPrenet", # Building part of bigger (tested) model.
|
||||
"SpeechT5SpeechDecoder", # Building part of bigger (tested) model.
|
||||
"SpeechT5SpeechEncoder", # Building part of bigger (tested) model.
|
||||
"SpeechT5TextDecoder", # Building part of bigger (tested) model.
|
||||
"SpeechT5TextEncoder", # Building part of bigger (tested) model.
|
||||
"BarkCausalModel", # Building part of bigger (tested) model.
|
||||
"BarkModel", # Does not have a forward signature - generation tested with integration tests
|
||||
]
|
||||
@@ -236,12 +177,6 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
|
||||
"AutoformerForPrediction",
|
||||
"JukeboxVQVAE",
|
||||
"JukeboxPrior",
|
||||
"PegasusXEncoder",
|
||||
"PegasusXDecoder",
|
||||
"PegasusXDecoderWrapper",
|
||||
"PegasusXEncoder",
|
||||
"PegasusXDecoder",
|
||||
"PegasusXDecoderWrapper",
|
||||
"SamModel",
|
||||
"DPTForDepthEstimation",
|
||||
"DecisionTransformerGPT2Model",
|
||||
@@ -250,17 +185,11 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
|
||||
"ViltForImageAndTextRetrieval",
|
||||
"ViltForTokenClassification",
|
||||
"ViltForMaskedLM",
|
||||
"XGLMEncoder",
|
||||
"XGLMDecoder",
|
||||
"XGLMDecoderWrapper",
|
||||
"PerceiverForMultimodalAutoencoding",
|
||||
"PerceiverForOpticalFlow",
|
||||
"SegformerDecodeHead",
|
||||
"TFSegformerDecodeHead",
|
||||
"FlaxBeitForMaskedImageModeling",
|
||||
"PLBartEncoder",
|
||||
"PLBartDecoder",
|
||||
"PLBartDecoderWrapper",
|
||||
"BeitForMaskedImageModeling",
|
||||
"ChineseCLIPTextModel",
|
||||
"ChineseCLIPVisionModel",
|
||||
@@ -347,7 +276,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
|
||||
]
|
||||
|
||||
# DO NOT edit this list!
|
||||
# (The corresponding pytorch objects should never be in the main `__init__`, but it's too late to remove)
|
||||
# (The corresponding pytorch objects should never have been in the main `__init__`, but it's too late to remove)
|
||||
OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK = [
|
||||
"FlaxBertLayer",
|
||||
"FlaxBigBirdLayer",
|
||||
@@ -361,8 +290,7 @@ OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK = [
|
||||
"TFViTMAELayer",
|
||||
]
|
||||
|
||||
# Update this list for models that have multiple model types for the same
|
||||
# model doc
|
||||
# Update this list for models that have multiple model types for the same model doc.
|
||||
MODEL_TYPE_TO_DOC_MAPPING = OrderedDict(
|
||||
[
|
||||
("data2vec-text", "data2vec"),
|
||||
@@ -378,6 +306,10 @@ transformers = direct_transformers_import(PATH_TO_TRANSFORMERS)
|
||||
|
||||
|
||||
def check_missing_backends():
|
||||
"""
|
||||
Checks if all backends are installed (otherwise the check of this script is incomplete). Will error in the CI if
|
||||
that's not the case but only throw a warning for users running this.
|
||||
"""
|
||||
missing_backends = []
|
||||
if not is_torch_available():
|
||||
missing_backends.append("PyTorch")
|
||||
@@ -402,7 +334,9 @@ def check_missing_backends():
|
||||
|
||||
|
||||
def check_model_list():
|
||||
"""Check the model list inside the transformers library."""
|
||||
"""
|
||||
Checks the model listed as subfolders of `models` match the models available in `transformers.models`.
|
||||
"""
|
||||
# Get the models from the directory structure of `src/transformers/models/`
|
||||
models_dir = os.path.join(PATH_TO_TRANSFORMERS, "models")
|
||||
_models = []
|
||||
@@ -413,7 +347,7 @@ def check_model_list():
|
||||
if os.path.isdir(model_dir) and "__init__.py" in os.listdir(model_dir):
|
||||
_models.append(model)
|
||||
|
||||
# Get the models from the directory structure of `src/transformers/models/`
|
||||
# Get the models in the submodule `transformers.models`
|
||||
models = [model for model in dir(transformers.models) if not model.startswith("__")]
|
||||
|
||||
missing_models = sorted(set(_models).difference(models))
|
||||
@@ -425,8 +359,8 @@ def check_model_list():
|
||||
|
||||
# If some modeling modules should be ignored for all checks, they should be added in the nested list
|
||||
# _ignore_modules of this function.
|
||||
def get_model_modules():
|
||||
"""Get the model modules inside the transformers library."""
|
||||
def get_model_modules() -> List[str]:
|
||||
"""Get all the model modules inside the transformers library (except deprecated models)."""
|
||||
_ignore_modules = [
|
||||
"modeling_auto",
|
||||
"modeling_encoder_decoder",
|
||||
@@ -454,21 +388,32 @@ def get_model_modules():
|
||||
]
|
||||
modules = []
|
||||
for model in dir(transformers.models):
|
||||
if model == "deprecated":
|
||||
continue
|
||||
# There are some magic dunder attributes in the dir, we ignore them
|
||||
if not model.startswith("__"):
|
||||
model_module = getattr(transformers.models, model)
|
||||
for submodule in dir(model_module):
|
||||
if submodule.startswith("modeling") and submodule not in _ignore_modules:
|
||||
modeling_module = getattr(model_module, submodule)
|
||||
if inspect.ismodule(modeling_module):
|
||||
modules.append(modeling_module)
|
||||
if model == "deprecated" or model.startswith("__"):
|
||||
continue
|
||||
|
||||
model_module = getattr(transformers.models, model)
|
||||
for submodule in dir(model_module):
|
||||
if submodule.startswith("modeling") and submodule not in _ignore_modules:
|
||||
modeling_module = getattr(model_module, submodule)
|
||||
if inspect.ismodule(modeling_module):
|
||||
modules.append(modeling_module)
|
||||
return modules
|
||||
|
||||
|
||||
def get_models(module, include_pretrained=False):
|
||||
"""Get the objects in module that are models."""
|
||||
def get_models(module: types.ModuleType, include_pretrained: bool = False) -> List[Tuple[str, type]]:
|
||||
"""
|
||||
Get the objects in a module that are models.
|
||||
|
||||
Args:
|
||||
module (`types.ModuleType`):
|
||||
The module from which we are extracting models.
|
||||
include_pretrained (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to include the `PreTrainedModel` subclass (like `BertPreTrainedModel`) or not.
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, type]]: List of models as tuples (class name, actual class).
|
||||
"""
|
||||
models = []
|
||||
model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel)
|
||||
for attr_name in dir(module):
|
||||
@@ -480,12 +425,10 @@ def get_models(module, include_pretrained=False):
|
||||
return models
|
||||
|
||||
|
||||
def is_a_private_model(model):
|
||||
"""Returns True if the model should not be in the main init."""
|
||||
if model in PRIVATE_MODELS:
|
||||
return True
|
||||
|
||||
# Wrapper, Encoder and Decoder are all privates
|
||||
def is_building_block(model: str) -> bool:
|
||||
"""
|
||||
Returns `True` if a model is a building block part of a bigger model.
|
||||
"""
|
||||
if model.endswith("Wrapper"):
|
||||
return True
|
||||
if model.endswith("Encoder"):
|
||||
@@ -494,7 +437,13 @@ def is_a_private_model(model):
|
||||
return True
|
||||
if model.endswith("Prenet"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_a_private_model(model: str) -> bool:
|
||||
"""Returns `True` if the model should not be in the main init."""
|
||||
if model in PRIVATE_MODELS:
|
||||
return True
|
||||
return is_building_block(model)
|
||||
|
||||
|
||||
def check_models_are_in_init():
|
||||
@@ -514,11 +463,14 @@ def check_models_are_in_init():
|
||||
|
||||
# If some test_modeling files should be ignored when checking models are all tested, they should be added in the
|
||||
# nested list _ignore_files of this function.
|
||||
def get_model_test_files():
|
||||
"""Get the model test files.
|
||||
def get_model_test_files() -> List[str]:
|
||||
"""
|
||||
Get the model test files.
|
||||
|
||||
The returned files should NOT contain the `tests` (i.e. `PATH_TO_TESTS` defined in this script). They will be
|
||||
considered as paths relative to `tests`. A caller has to use `os.path.join(PATH_TO_TESTS, ...)` to access the files.
|
||||
Returns:
|
||||
`List[str]`: The list of test files. The returned files will NOT contain the `tests` (i.e. `PATH_TO_TESTS`
|
||||
defined in this script). They will be considered as paths relative to `tests`. A caller has to use
|
||||
`os.path.join(PATH_TO_TESTS, ...)` to access the files.
|
||||
"""
|
||||
|
||||
_ignore_files = [
|
||||
@@ -531,7 +483,6 @@ def get_model_test_files():
|
||||
"test_modeling_tf_encoder_decoder",
|
||||
]
|
||||
test_files = []
|
||||
# Check both `PATH_TO_TESTS` and `PATH_TO_TESTS/models`
|
||||
model_test_root = os.path.join(PATH_TO_TESTS, "models")
|
||||
model_test_dirs = []
|
||||
for x in os.listdir(model_test_root):
|
||||
@@ -553,9 +504,17 @@ def get_model_test_files():
|
||||
|
||||
# This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the tester class
|
||||
# for the all_model_classes variable.
|
||||
def find_tested_models(test_file):
|
||||
"""Parse the content of test_file to detect what's in all_model_classes"""
|
||||
# This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the class
|
||||
def find_tested_models(test_file: str) -> List[str]:
|
||||
"""
|
||||
Parse the content of test_file to detect what's in `all_model_classes`. This detects the models that inherit from
|
||||
the common test class.
|
||||
|
||||
Args:
|
||||
test_file (`str`): The path to the test file to check
|
||||
|
||||
Returns:
|
||||
`List[str]`: The list of models tested in that file.
|
||||
"""
|
||||
with open(os.path.join(PATH_TO_TESTS, test_file), "r", encoding="utf-8", newline="\n") as f:
|
||||
content = f.read()
|
||||
all_models = re.findall(r"all_model_classes\s+=\s+\(\s*\(([^\)]*)\)", content)
|
||||
@@ -571,8 +530,25 @@ def find_tested_models(test_file):
|
||||
return model_tested
|
||||
|
||||
|
||||
def check_models_are_tested(module, test_file):
|
||||
"""Check models defined in module are tested in test_file."""
|
||||
def should_be_tested(model_name: str) -> bool:
|
||||
"""
|
||||
Whether or not a model should be tested.
|
||||
"""
|
||||
if model_name in IGNORE_NON_TESTED:
|
||||
return False
|
||||
return not is_building_block(model_name)
|
||||
|
||||
|
||||
def check_models_are_tested(module: types.ModuleType, test_file: str) -> List[str]:
|
||||
"""Check models defined in a module are all tested in a given file.
|
||||
|
||||
Args:
|
||||
module (`types.ModuleType`): The module in which we get the models.
|
||||
test_file (`str`): The path to the file where the module is tested.
|
||||
|
||||
Returns:
|
||||
`List[str]`: The list of error messages corresponding to models not tested.
|
||||
"""
|
||||
# XxxPreTrainedModel are not tested
|
||||
defined_models = get_models(module)
|
||||
tested_models = find_tested_models(test_file)
|
||||
@@ -586,7 +562,7 @@ def check_models_are_tested(module, test_file):
|
||||
]
|
||||
failures = []
|
||||
for model_name, _ in defined_models:
|
||||
if model_name not in tested_models and model_name not in IGNORE_NON_TESTED:
|
||||
if model_name not in tested_models and should_be_tested(model_name):
|
||||
failures.append(
|
||||
f"{model_name} is defined in {module.__name__} but is not tested in "
|
||||
+ f"{os.path.join(PATH_TO_TESTS, test_file)}. Add it to the all_model_classes in that file."
|
||||
@@ -602,6 +578,7 @@ def check_all_models_are_tested():
|
||||
test_files = get_model_test_files()
|
||||
failures = []
|
||||
for module in modules:
|
||||
# Matches a module to its test file.
|
||||
test_file = [file for file in test_files if f"test_{module.__name__.split('.')[-1]}.py" in file]
|
||||
if len(test_file) == 0:
|
||||
failures.append(f"{module.__name__} does not have its corresponding test file {test_file}.")
|
||||
@@ -616,7 +593,7 @@ def check_all_models_are_tested():
|
||||
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
|
||||
|
||||
|
||||
def get_all_auto_configured_models():
|
||||
def get_all_auto_configured_models() -> List[str]:
|
||||
"""Return the list of all models in at least one auto class."""
|
||||
result = set() # To avoid duplicates we concatenate all model classes in a set.
|
||||
if is_torch_available():
|
||||
@@ -634,8 +611,8 @@ def get_all_auto_configured_models():
|
||||
return list(result)
|
||||
|
||||
|
||||
def ignore_unautoclassed(model_name):
|
||||
"""Rules to determine if `name` should be in an auto class."""
|
||||
def ignore_unautoclassed(model_name: str) -> bool:
|
||||
"""Rules to determine if a model should be in an auto class."""
|
||||
# Special white list
|
||||
if model_name in IGNORE_NON_AUTO_CONFIGURED:
|
||||
return True
|
||||
@@ -645,8 +622,19 @@ def ignore_unautoclassed(model_name):
|
||||
return False
|
||||
|
||||
|
||||
def check_models_are_auto_configured(module, all_auto_models):
|
||||
"""Check models defined in module are each in an auto class."""
|
||||
def check_models_are_auto_configured(module: types.ModuleType, all_auto_models: List[str]) -> List[str]:
|
||||
"""
|
||||
Check models defined in module are each in an auto class.
|
||||
|
||||
Args:
|
||||
module (`types.ModuleType`):
|
||||
The module in which we get the models.
|
||||
all_auto_models (`List[str]`):
|
||||
The list of all models in an auto class (as obtained with `get_all_auto_configured_models()`).
|
||||
|
||||
Returns:
|
||||
`List[str]`: The list of error messages corresponding to models not tested.
|
||||
"""
|
||||
defined_models = get_models(module)
|
||||
failures = []
|
||||
for model_name, _ in defined_models:
|
||||
@@ -661,6 +649,7 @@ def check_models_are_auto_configured(module, all_auto_models):
|
||||
|
||||
def check_all_models_are_auto_configured():
|
||||
"""Check all models are each in an auto class."""
|
||||
# This is where we need to check we have all backends or the check is incomplete.
|
||||
check_missing_backends()
|
||||
modules = get_model_modules()
|
||||
all_auto_models = get_all_auto_configured_models()
|
||||
@@ -675,6 +664,7 @@ def check_all_models_are_auto_configured():
|
||||
|
||||
def check_all_auto_object_names_being_defined():
|
||||
"""Check all names defined in auto (name) mappings exist in the library."""
|
||||
# This is where we need to check we have all backends or the check is incomplete.
|
||||
check_missing_backends()
|
||||
|
||||
failures = []
|
||||
@@ -695,7 +685,7 @@ def check_all_auto_object_names_being_defined():
|
||||
mappings_to_check.update({name: getattr(module, name) for name in mapping_names})
|
||||
|
||||
for name, mapping in mappings_to_check.items():
|
||||
for model_type, class_names in mapping.items():
|
||||
for _, class_names in mapping.items():
|
||||
if not isinstance(class_names, tuple):
|
||||
class_names = (class_names,)
|
||||
for class_name in class_names:
|
||||
@@ -716,6 +706,7 @@ def check_all_auto_object_names_being_defined():
|
||||
|
||||
def check_all_auto_mapping_names_in_config_mapping_names():
|
||||
"""Check all keys defined in auto mappings (mappings of names) appear in `CONFIG_MAPPING_NAMES`."""
|
||||
# This is where we need to check we have all backends or the check is incomplete.
|
||||
check_missing_backends()
|
||||
|
||||
failures = []
|
||||
@@ -736,7 +727,7 @@ def check_all_auto_mapping_names_in_config_mapping_names():
|
||||
mappings_to_check.update({name: getattr(module, name) for name in mapping_names})
|
||||
|
||||
for name, mapping in mappings_to_check.items():
|
||||
for model_type, class_names in mapping.items():
|
||||
for model_type in mapping:
|
||||
if model_type not in CONFIG_MAPPING_NAMES:
|
||||
failures.append(
|
||||
f"`{model_type}` appears in the mapping `{name}` but it is not defined in the keys of "
|
||||
@@ -747,7 +738,8 @@ def check_all_auto_mapping_names_in_config_mapping_names():
|
||||
|
||||
|
||||
def check_all_auto_mappings_importable():
|
||||
"""Check all auto mappings could be imported."""
|
||||
"""Check all auto mappings can be imported."""
|
||||
# This is where we need to check we have all backends or the check is incomplete.
|
||||
check_missing_backends()
|
||||
|
||||
failures = []
|
||||
@@ -761,7 +753,7 @@ def check_all_auto_mappings_importable():
|
||||
mapping_names = [x for x in dir(module) if x.endswith("_MAPPING_NAMES")]
|
||||
mappings_to_check.update({name: getattr(module, name) for name in mapping_names})
|
||||
|
||||
for name, _ in mappings_to_check.items():
|
||||
for name in mappings_to_check:
|
||||
name = name.replace("_MAPPING_NAMES", "_MAPPING")
|
||||
if not hasattr(transformers, name):
|
||||
failures.append(f"`{name}`")
|
||||
@@ -770,44 +762,46 @@ def check_all_auto_mappings_importable():
|
||||
|
||||
|
||||
def check_objects_being_equally_in_main_init():
|
||||
"""Check if an object is in the main __init__ if its counterpart in PyTorch is."""
|
||||
"""
|
||||
Check if a (TensorFlow or Flax) object is in the main __init__ iif its counterpart in PyTorch is.
|
||||
"""
|
||||
attrs = dir(transformers)
|
||||
|
||||
failures = []
|
||||
for attr in attrs:
|
||||
obj = getattr(transformers, attr)
|
||||
if hasattr(obj, "__module__"):
|
||||
module_path = obj.__module__
|
||||
if "models.deprecated" in module_path:
|
||||
continue
|
||||
module_name = module_path.split(".")[-1]
|
||||
module_dir = ".".join(module_path.split(".")[:-1])
|
||||
if (
|
||||
module_name.startswith("modeling_")
|
||||
and not module_name.startswith("modeling_tf_")
|
||||
and not module_name.startswith("modeling_flax_")
|
||||
):
|
||||
parent_module = sys.modules[module_dir]
|
||||
if not hasattr(obj, "__module__") or "models.deprecated" in obj.__module__:
|
||||
continue
|
||||
|
||||
frameworks = []
|
||||
if is_tf_available():
|
||||
frameworks.append("TF")
|
||||
if is_flax_available():
|
||||
frameworks.append("Flax")
|
||||
module_path = obj.__module__
|
||||
module_name = module_path.split(".")[-1]
|
||||
module_dir = ".".join(module_path.split(".")[:-1])
|
||||
if (
|
||||
module_name.startswith("modeling_")
|
||||
and not module_name.startswith("modeling_tf_")
|
||||
and not module_name.startswith("modeling_flax_")
|
||||
):
|
||||
parent_module = sys.modules[module_dir]
|
||||
|
||||
for framework in frameworks:
|
||||
other_module_path = module_path.replace("modeling_", f"modeling_{framework.lower()}_")
|
||||
if os.path.isfile("src/" + other_module_path.replace(".", "/") + ".py"):
|
||||
other_module_name = module_name.replace("modeling_", f"modeling_{framework.lower()}_")
|
||||
other_module = getattr(parent_module, other_module_name)
|
||||
if hasattr(other_module, f"{framework}{attr}"):
|
||||
if not hasattr(transformers, f"{framework}{attr}"):
|
||||
if f"{framework}{attr}" not in OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK:
|
||||
failures.append(f"{framework}{attr}")
|
||||
if hasattr(other_module, f"{framework}_{attr}"):
|
||||
if not hasattr(transformers, f"{framework}_{attr}"):
|
||||
if f"{framework}_{attr}" not in OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK:
|
||||
failures.append(f"{framework}_{attr}")
|
||||
frameworks = []
|
||||
if is_tf_available():
|
||||
frameworks.append("TF")
|
||||
if is_flax_available():
|
||||
frameworks.append("Flax")
|
||||
|
||||
for framework in frameworks:
|
||||
other_module_path = module_path.replace("modeling_", f"modeling_{framework.lower()}_")
|
||||
if os.path.isfile("src/" + other_module_path.replace(".", "/") + ".py"):
|
||||
other_module_name = module_name.replace("modeling_", f"modeling_{framework.lower()}_")
|
||||
other_module = getattr(parent_module, other_module_name)
|
||||
if hasattr(other_module, f"{framework}{attr}"):
|
||||
if not hasattr(transformers, f"{framework}{attr}"):
|
||||
if f"{framework}{attr}" not in OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK:
|
||||
failures.append(f"{framework}{attr}")
|
||||
if hasattr(other_module, f"{framework}_{attr}"):
|
||||
if not hasattr(transformers, f"{framework}_{attr}"):
|
||||
if f"{framework}_{attr}" not in OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK:
|
||||
failures.append(f"{framework}_{attr}")
|
||||
if len(failures) > 0:
|
||||
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
|
||||
|
||||
@@ -815,8 +809,16 @@ def check_objects_being_equally_in_main_init():
|
||||
_re_decorator = re.compile(r"^\s*@(\S+)\s+$")
|
||||
|
||||
|
||||
def check_decorator_order(filename):
|
||||
"""Check that in the test file `filename` the slow decorator is always last."""
|
||||
def check_decorator_order(filename: str) -> List[int]:
|
||||
"""
|
||||
Check that in a given test file, the slow decorator is always last.
|
||||
|
||||
Args:
|
||||
filename (`str`): The path to a test file to check.
|
||||
|
||||
Returns:
|
||||
`List[int]`: The list of failures as a list of indices where there are problems.
|
||||
"""
|
||||
with open(filename, "r", encoding="utf-8", newline="\n") as f:
|
||||
lines = f.readlines()
|
||||
decorator_before = None
|
||||
@@ -849,8 +851,13 @@ def check_all_decorator_order():
|
||||
)
|
||||
|
||||
|
||||
def find_all_documented_objects():
|
||||
"""Parse the content of all doc files to detect which classes and functions it documents"""
|
||||
def find_all_documented_objects() -> List[str]:
|
||||
"""
|
||||
Parse the content of all doc files to detect which classes and functions it documents.
|
||||
|
||||
Returns:
|
||||
`List[str]`: The list of all object names being documented.
|
||||
"""
|
||||
documented_obj = []
|
||||
for doc_file in Path(PATH_TO_DOC).glob("**/*.rst"):
|
||||
with open(doc_file, "r", encoding="utf-8", newline="\n") as f:
|
||||
@@ -959,8 +966,8 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [
|
||||
]
|
||||
|
||||
|
||||
def ignore_undocumented(name):
|
||||
"""Rules to determine if `name` should be undocumented."""
|
||||
def ignore_undocumented(name: str) -> bool:
|
||||
"""Rules to determine if `name` should be undocumented (returns `True` if it should not be documented)."""
|
||||
# NOT DOCUMENTED ON PURPOSE.
|
||||
# Constants uppercase are not documented.
|
||||
if name.isupper():
|
||||
@@ -1047,7 +1054,7 @@ _re_double_backquotes = re.compile(r"(^|[^`])``([^`]+)``([^`]|$)")
|
||||
_re_rst_example = re.compile(r"^\s*Example.*::\s*$", flags=re.MULTILINE)
|
||||
|
||||
|
||||
def is_rst_docstring(docstring):
|
||||
def is_rst_docstring(docstring: str) -> True:
|
||||
"""
|
||||
Returns `True` if `docstring` is written in rst.
|
||||
"""
|
||||
@@ -1061,7 +1068,7 @@ def is_rst_docstring(docstring):
|
||||
|
||||
|
||||
def check_docstrings_are_in_md():
|
||||
"""Check all docstrings are in md"""
|
||||
"""Check all docstrings are written in md and nor rst."""
|
||||
files_with_rst = []
|
||||
for file in Path(PATH_TO_TRANSFORMERS).glob("**/*.py"):
|
||||
with open(file, encoding="utf-8") as f:
|
||||
@@ -1084,6 +1091,9 @@ def check_docstrings_are_in_md():
|
||||
|
||||
|
||||
def check_deprecated_constant_is_up_to_date():
|
||||
"""
|
||||
Check if the constant `DEPRECATED_MODELS` in `models/auto/configuration_auto.py` is up to date.
|
||||
"""
|
||||
deprecated_folder = os.path.join(PATH_TO_TRANSFORMERS, "models", "deprecated")
|
||||
deprecated_models = [m for m in os.listdir(deprecated_folder) if not m.startswith("_")]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user