Doc checks (#25408)

* Document check_dummies

* Type hints and doc in other files

* Document check inits

* Add documentation to

* Address review comments
This commit is contained in:
Sylvain Gugger
2023-08-10 10:53:22 +02:00
committed by GitHub
parent b14d4641f6
commit 16edf4d9fd
6 changed files with 459 additions and 224 deletions

View File

@@ -12,15 +12,34 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility that performs several consistency checks on the repo. This includes:
- checking all models are properly defined in the __init__ of models/
- checking all models are in the main __init__
- checking all models are properly tested
- checking all object in the main __init__ are documented
- checking all models are in at least one auto class
- checking all the auto mapping are properly defined (no typos, importable)
- checking the list of deprecated models is up to date
Use from the root of the repo with (as used in `make repo-consistency`):
```bash
python utils/check_repo.py
```
It has no auto-fix mode.
"""
import inspect
import os
import re
import sys
import types
import warnings
from collections import OrderedDict
from difflib import get_close_matches
from pathlib import Path
from typing import List, Tuple
from transformers import is_flax_available, is_tf_available, is_torch_available
from transformers.models.auto import get_values
@@ -60,91 +79,25 @@ PRIVATE_MODELS = [
IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
# models to ignore for not tested
"InstructBlipQFormerModel", # Building part of bigger (tested) model.
"NllbMoeDecoder",
"NllbMoeEncoder",
"UMT5EncoderModel", # Building part of bigger (tested) model.
"LlamaDecoder", # Building part of bigger (tested) model.
"Blip2QFormerModel", # Building part of bigger (tested) model.
"DetaEncoder", # Building part of bigger (tested) model.
"DetaDecoder", # Building part of bigger (tested) model.
"ErnieMForInformationExtraction",
"GraphormerEncoder", # Building part of bigger (tested) model.
"GraphormerDecoderHead", # Building part of bigger (tested) model.
"CLIPSegDecoder", # Building part of bigger (tested) model.
"TableTransformerEncoder", # Building part of bigger (tested) model.
"TableTransformerDecoder", # Building part of bigger (tested) model.
"TimeSeriesTransformerEncoder", # Building part of bigger (tested) model.
"TimeSeriesTransformerDecoder", # Building part of bigger (tested) model.
"InformerEncoder", # Building part of bigger (tested) model.
"InformerDecoder", # Building part of bigger (tested) model.
"AutoformerEncoder", # Building part of bigger (tested) model.
"AutoformerDecoder", # Building part of bigger (tested) model.
"JukeboxVQVAE", # Building part of bigger (tested) model.
"JukeboxPrior", # Building part of bigger (tested) model.
"DeformableDetrEncoder", # Building part of bigger (tested) model.
"DeformableDetrDecoder", # Building part of bigger (tested) model.
"OPTDecoder", # Building part of bigger (tested) model.
"FlaxWhisperDecoder", # Building part of bigger (tested) model.
"FlaxWhisperEncoder", # Building part of bigger (tested) model.
"WhisperDecoder", # Building part of bigger (tested) model.
"WhisperEncoder", # Building part of bigger (tested) model.
"DecisionTransformerGPT2Model", # Building part of bigger (tested) model.
"SegformerDecodeHead", # Building part of bigger (tested) model.
"PLBartEncoder", # Building part of bigger (tested) model.
"PLBartDecoder", # Building part of bigger (tested) model.
"PLBartDecoderWrapper", # Building part of bigger (tested) model.
"BigBirdPegasusEncoder", # Building part of bigger (tested) model.
"BigBirdPegasusDecoder", # Building part of bigger (tested) model.
"BigBirdPegasusDecoderWrapper", # Building part of bigger (tested) model.
"DetrEncoder", # Building part of bigger (tested) model.
"DetrDecoder", # Building part of bigger (tested) model.
"DetrDecoderWrapper", # Building part of bigger (tested) model.
"ConditionalDetrEncoder", # Building part of bigger (tested) model.
"ConditionalDetrDecoder", # Building part of bigger (tested) model.
"M2M100Encoder", # Building part of bigger (tested) model.
"M2M100Decoder", # Building part of bigger (tested) model.
"MCTCTEncoder", # Building part of bigger (tested) model.
"MgpstrModel", # Building part of bigger (tested) model.
"Speech2TextEncoder", # Building part of bigger (tested) model.
"Speech2TextDecoder", # Building part of bigger (tested) model.
"LEDEncoder", # Building part of bigger (tested) model.
"LEDDecoder", # Building part of bigger (tested) model.
"BartDecoderWrapper", # Building part of bigger (tested) model.
"BartEncoder", # Building part of bigger (tested) model.
"BertLMHeadModel", # Needs to be setup as decoder.
"BlenderbotSmallEncoder", # Building part of bigger (tested) model.
"BlenderbotSmallDecoderWrapper", # Building part of bigger (tested) model.
"BlenderbotEncoder", # Building part of bigger (tested) model.
"BlenderbotDecoderWrapper", # Building part of bigger (tested) model.
"MBartEncoder", # Building part of bigger (tested) model.
"MBartDecoderWrapper", # Building part of bigger (tested) model.
"MegatronBertLMHeadModel", # Building part of bigger (tested) model.
"MegatronBertEncoder", # Building part of bigger (tested) model.
"MegatronBertDecoder", # Building part of bigger (tested) model.
"MegatronBertDecoderWrapper", # Building part of bigger (tested) model.
"MusicgenDecoder", # Building part of bigger (tested) model.
"MvpDecoderWrapper", # Building part of bigger (tested) model.
"MvpEncoder", # Building part of bigger (tested) model.
"PegasusEncoder", # Building part of bigger (tested) model.
"PegasusDecoderWrapper", # Building part of bigger (tested) model.
"PegasusXEncoder", # Building part of bigger (tested) model.
"PegasusXDecoder", # Building part of bigger (tested) model.
"PegasusXDecoderWrapper", # Building part of bigger (tested) model.
"DPREncoder", # Building part of bigger (tested) model.
"ProphetNetDecoderWrapper", # Building part of bigger (tested) model.
"RealmBertModel", # Building part of bigger (tested) model.
"RealmReader", # Not regular model.
"RealmScorer", # Not regular model.
"RealmForOpenQA", # Not regular model.
"ReformerForMaskedLM", # Needs to be setup as decoder.
"Speech2Text2DecoderWrapper", # Building part of bigger (tested) model.
"TFDPREncoder", # Building part of bigger (tested) model.
"TFElectraMainLayer", # Building part of bigger (tested) model (should it be a TFPreTrainedModel ?)
"TFRobertaForMultipleChoice", # TODO: fix
"TFRobertaPreLayerNormForMultipleChoice", # TODO: fix
"TrOCRDecoderWrapper", # Building part of bigger (tested) model.
"TFWhisperEncoder", # Building part of bigger (tested) model.
"TFWhisperDecoder", # Building part of bigger (tested) model.
"SeparableConv1D", # Building part of bigger (tested) model.
"FlaxBartForCausalLM", # Building part of bigger (tested) model.
"FlaxBertForCausalLM", # Building part of bigger (tested) model. Tested implicitly through FlaxRobertaForCausalLM.
@@ -155,18 +108,6 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
"TFBlipTextLMHeadModel", # No need to test it as it is tested by BlipTextVision models
"BridgeTowerTextModel", # No need to test it as it is tested by BridgeTowerModel model.
"BridgeTowerVisionModel", # No need to test it as it is tested by BridgeTowerModel model.
"SpeechT5Decoder", # Building part of bigger (tested) model.
"SpeechT5DecoderWithoutPrenet", # Building part of bigger (tested) model.
"SpeechT5DecoderWithSpeechPrenet", # Building part of bigger (tested) model.
"SpeechT5DecoderWithTextPrenet", # Building part of bigger (tested) model.
"SpeechT5Encoder", # Building part of bigger (tested) model.
"SpeechT5EncoderWithoutPrenet", # Building part of bigger (tested) model.
"SpeechT5EncoderWithSpeechPrenet", # Building part of bigger (tested) model.
"SpeechT5EncoderWithTextPrenet", # Building part of bigger (tested) model.
"SpeechT5SpeechDecoder", # Building part of bigger (tested) model.
"SpeechT5SpeechEncoder", # Building part of bigger (tested) model.
"SpeechT5TextDecoder", # Building part of bigger (tested) model.
"SpeechT5TextEncoder", # Building part of bigger (tested) model.
"BarkCausalModel", # Building part of bigger (tested) model.
"BarkModel", # Does not have a forward signature - generation tested with integration tests
]
@@ -236,12 +177,6 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"AutoformerForPrediction",
"JukeboxVQVAE",
"JukeboxPrior",
"PegasusXEncoder",
"PegasusXDecoder",
"PegasusXDecoderWrapper",
"PegasusXEncoder",
"PegasusXDecoder",
"PegasusXDecoderWrapper",
"SamModel",
"DPTForDepthEstimation",
"DecisionTransformerGPT2Model",
@@ -250,17 +185,11 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"ViltForImageAndTextRetrieval",
"ViltForTokenClassification",
"ViltForMaskedLM",
"XGLMEncoder",
"XGLMDecoder",
"XGLMDecoderWrapper",
"PerceiverForMultimodalAutoencoding",
"PerceiverForOpticalFlow",
"SegformerDecodeHead",
"TFSegformerDecodeHead",
"FlaxBeitForMaskedImageModeling",
"PLBartEncoder",
"PLBartDecoder",
"PLBartDecoderWrapper",
"BeitForMaskedImageModeling",
"ChineseCLIPTextModel",
"ChineseCLIPVisionModel",
@@ -347,7 +276,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
]
# DO NOT edit this list!
# (The corresponding pytorch objects should never be in the main `__init__`, but it's too late to remove)
# (The corresponding pytorch objects should never have been in the main `__init__`, but it's too late to remove)
OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK = [
"FlaxBertLayer",
"FlaxBigBirdLayer",
@@ -361,8 +290,7 @@ OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK = [
"TFViTMAELayer",
]
# Update this list for models that have multiple model types for the same
# model doc
# Update this list for models that have multiple model types for the same model doc.
MODEL_TYPE_TO_DOC_MAPPING = OrderedDict(
[
("data2vec-text", "data2vec"),
@@ -378,6 +306,10 @@ transformers = direct_transformers_import(PATH_TO_TRANSFORMERS)
def check_missing_backends():
"""
Checks if all backends are installed (otherwise the check of this script is incomplete). Will error in the CI if
that's not the case but only throw a warning for users running this.
"""
missing_backends = []
if not is_torch_available():
missing_backends.append("PyTorch")
@@ -402,7 +334,9 @@ def check_missing_backends():
def check_model_list():
"""Check the model list inside the transformers library."""
"""
Checks the model listed as subfolders of `models` match the models available in `transformers.models`.
"""
# Get the models from the directory structure of `src/transformers/models/`
models_dir = os.path.join(PATH_TO_TRANSFORMERS, "models")
_models = []
@@ -413,7 +347,7 @@ def check_model_list():
if os.path.isdir(model_dir) and "__init__.py" in os.listdir(model_dir):
_models.append(model)
# Get the models from the directory structure of `src/transformers/models/`
# Get the models in the submodule `transformers.models`
models = [model for model in dir(transformers.models) if not model.startswith("__")]
missing_models = sorted(set(_models).difference(models))
@@ -425,8 +359,8 @@ def check_model_list():
# If some modeling modules should be ignored for all checks, they should be added in the nested list
# _ignore_modules of this function.
def get_model_modules():
"""Get the model modules inside the transformers library."""
def get_model_modules() -> List[str]:
"""Get all the model modules inside the transformers library (except deprecated models)."""
_ignore_modules = [
"modeling_auto",
"modeling_encoder_decoder",
@@ -454,21 +388,32 @@ def get_model_modules():
]
modules = []
for model in dir(transformers.models):
if model == "deprecated":
continue
# There are some magic dunder attributes in the dir, we ignore them
if not model.startswith("__"):
model_module = getattr(transformers.models, model)
for submodule in dir(model_module):
if submodule.startswith("modeling") and submodule not in _ignore_modules:
modeling_module = getattr(model_module, submodule)
if inspect.ismodule(modeling_module):
modules.append(modeling_module)
if model == "deprecated" or model.startswith("__"):
continue
model_module = getattr(transformers.models, model)
for submodule in dir(model_module):
if submodule.startswith("modeling") and submodule not in _ignore_modules:
modeling_module = getattr(model_module, submodule)
if inspect.ismodule(modeling_module):
modules.append(modeling_module)
return modules
def get_models(module, include_pretrained=False):
"""Get the objects in module that are models."""
def get_models(module: types.ModuleType, include_pretrained: bool = False) -> List[Tuple[str, type]]:
"""
Get the objects in a module that are models.
Args:
module (`types.ModuleType`):
The module from which we are extracting models.
include_pretrained (`bool`, *optional*, defaults to `False`):
Whether or not to include the `PreTrainedModel` subclass (like `BertPreTrainedModel`) or not.
Returns:
List[Tuple[str, type]]: List of models as tuples (class name, actual class).
"""
models = []
model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel)
for attr_name in dir(module):
@@ -480,12 +425,10 @@ def get_models(module, include_pretrained=False):
return models
def is_a_private_model(model):
"""Returns True if the model should not be in the main init."""
if model in PRIVATE_MODELS:
return True
# Wrapper, Encoder and Decoder are all privates
def is_building_block(model: str) -> bool:
"""
Returns `True` if a model is a building block part of a bigger model.
"""
if model.endswith("Wrapper"):
return True
if model.endswith("Encoder"):
@@ -494,7 +437,13 @@ def is_a_private_model(model):
return True
if model.endswith("Prenet"):
return True
return False
def is_a_private_model(model: str) -> bool:
"""Returns `True` if the model should not be in the main init."""
if model in PRIVATE_MODELS:
return True
return is_building_block(model)
def check_models_are_in_init():
@@ -514,11 +463,14 @@ def check_models_are_in_init():
# If some test_modeling files should be ignored when checking models are all tested, they should be added in the
# nested list _ignore_files of this function.
def get_model_test_files():
"""Get the model test files.
def get_model_test_files() -> List[str]:
"""
Get the model test files.
The returned files should NOT contain the `tests` (i.e. `PATH_TO_TESTS` defined in this script). They will be
considered as paths relative to `tests`. A caller has to use `os.path.join(PATH_TO_TESTS, ...)` to access the files.
Returns:
`List[str]`: The list of test files. The returned files will NOT contain the `tests` (i.e. `PATH_TO_TESTS`
defined in this script). They will be considered as paths relative to `tests`. A caller has to use
`os.path.join(PATH_TO_TESTS, ...)` to access the files.
"""
_ignore_files = [
@@ -531,7 +483,6 @@ def get_model_test_files():
"test_modeling_tf_encoder_decoder",
]
test_files = []
# Check both `PATH_TO_TESTS` and `PATH_TO_TESTS/models`
model_test_root = os.path.join(PATH_TO_TESTS, "models")
model_test_dirs = []
for x in os.listdir(model_test_root):
@@ -553,9 +504,17 @@ def get_model_test_files():
# This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the tester class
# for the all_model_classes variable.
def find_tested_models(test_file):
"""Parse the content of test_file to detect what's in all_model_classes"""
# This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the class
def find_tested_models(test_file: str) -> List[str]:
"""
Parse the content of test_file to detect what's in `all_model_classes`. This detects the models that inherit from
the common test class.
Args:
test_file (`str`): The path to the test file to check
Returns:
`List[str]`: The list of models tested in that file.
"""
with open(os.path.join(PATH_TO_TESTS, test_file), "r", encoding="utf-8", newline="\n") as f:
content = f.read()
all_models = re.findall(r"all_model_classes\s+=\s+\(\s*\(([^\)]*)\)", content)
@@ -571,8 +530,25 @@ def find_tested_models(test_file):
return model_tested
def check_models_are_tested(module, test_file):
"""Check models defined in module are tested in test_file."""
def should_be_tested(model_name: str) -> bool:
"""
Whether or not a model should be tested.
"""
if model_name in IGNORE_NON_TESTED:
return False
return not is_building_block(model_name)
def check_models_are_tested(module: types.ModuleType, test_file: str) -> List[str]:
"""Check models defined in a module are all tested in a given file.
Args:
module (`types.ModuleType`): The module in which we get the models.
test_file (`str`): The path to the file where the module is tested.
Returns:
`List[str]`: The list of error messages corresponding to models not tested.
"""
# XxxPreTrainedModel are not tested
defined_models = get_models(module)
tested_models = find_tested_models(test_file)
@@ -586,7 +562,7 @@ def check_models_are_tested(module, test_file):
]
failures = []
for model_name, _ in defined_models:
if model_name not in tested_models and model_name not in IGNORE_NON_TESTED:
if model_name not in tested_models and should_be_tested(model_name):
failures.append(
f"{model_name} is defined in {module.__name__} but is not tested in "
+ f"{os.path.join(PATH_TO_TESTS, test_file)}. Add it to the all_model_classes in that file."
@@ -602,6 +578,7 @@ def check_all_models_are_tested():
test_files = get_model_test_files()
failures = []
for module in modules:
# Matches a module to its test file.
test_file = [file for file in test_files if f"test_{module.__name__.split('.')[-1]}.py" in file]
if len(test_file) == 0:
failures.append(f"{module.__name__} does not have its corresponding test file {test_file}.")
@@ -616,7 +593,7 @@ def check_all_models_are_tested():
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
def get_all_auto_configured_models():
def get_all_auto_configured_models() -> List[str]:
"""Return the list of all models in at least one auto class."""
result = set() # To avoid duplicates we concatenate all model classes in a set.
if is_torch_available():
@@ -634,8 +611,8 @@ def get_all_auto_configured_models():
return list(result)
def ignore_unautoclassed(model_name):
"""Rules to determine if `name` should be in an auto class."""
def ignore_unautoclassed(model_name: str) -> bool:
"""Rules to determine if a model should be in an auto class."""
# Special white list
if model_name in IGNORE_NON_AUTO_CONFIGURED:
return True
@@ -645,8 +622,19 @@ def ignore_unautoclassed(model_name):
return False
def check_models_are_auto_configured(module, all_auto_models):
"""Check models defined in module are each in an auto class."""
def check_models_are_auto_configured(module: types.ModuleType, all_auto_models: List[str]) -> List[str]:
"""
Check models defined in module are each in an auto class.
Args:
module (`types.ModuleType`):
The module in which we get the models.
all_auto_models (`List[str]`):
The list of all models in an auto class (as obtained with `get_all_auto_configured_models()`).
Returns:
`List[str]`: The list of error messages corresponding to models not tested.
"""
defined_models = get_models(module)
failures = []
for model_name, _ in defined_models:
@@ -661,6 +649,7 @@ def check_models_are_auto_configured(module, all_auto_models):
def check_all_models_are_auto_configured():
"""Check all models are each in an auto class."""
# This is where we need to check we have all backends or the check is incomplete.
check_missing_backends()
modules = get_model_modules()
all_auto_models = get_all_auto_configured_models()
@@ -675,6 +664,7 @@ def check_all_models_are_auto_configured():
def check_all_auto_object_names_being_defined():
"""Check all names defined in auto (name) mappings exist in the library."""
# This is where we need to check we have all backends or the check is incomplete.
check_missing_backends()
failures = []
@@ -695,7 +685,7 @@ def check_all_auto_object_names_being_defined():
mappings_to_check.update({name: getattr(module, name) for name in mapping_names})
for name, mapping in mappings_to_check.items():
for model_type, class_names in mapping.items():
for _, class_names in mapping.items():
if not isinstance(class_names, tuple):
class_names = (class_names,)
for class_name in class_names:
@@ -716,6 +706,7 @@ def check_all_auto_object_names_being_defined():
def check_all_auto_mapping_names_in_config_mapping_names():
"""Check all keys defined in auto mappings (mappings of names) appear in `CONFIG_MAPPING_NAMES`."""
# This is where we need to check we have all backends or the check is incomplete.
check_missing_backends()
failures = []
@@ -736,7 +727,7 @@ def check_all_auto_mapping_names_in_config_mapping_names():
mappings_to_check.update({name: getattr(module, name) for name in mapping_names})
for name, mapping in mappings_to_check.items():
for model_type, class_names in mapping.items():
for model_type in mapping:
if model_type not in CONFIG_MAPPING_NAMES:
failures.append(
f"`{model_type}` appears in the mapping `{name}` but it is not defined in the keys of "
@@ -747,7 +738,8 @@ def check_all_auto_mapping_names_in_config_mapping_names():
def check_all_auto_mappings_importable():
"""Check all auto mappings could be imported."""
"""Check all auto mappings can be imported."""
# This is where we need to check we have all backends or the check is incomplete.
check_missing_backends()
failures = []
@@ -761,7 +753,7 @@ def check_all_auto_mappings_importable():
mapping_names = [x for x in dir(module) if x.endswith("_MAPPING_NAMES")]
mappings_to_check.update({name: getattr(module, name) for name in mapping_names})
for name, _ in mappings_to_check.items():
for name in mappings_to_check:
name = name.replace("_MAPPING_NAMES", "_MAPPING")
if not hasattr(transformers, name):
failures.append(f"`{name}`")
@@ -770,44 +762,46 @@ def check_all_auto_mappings_importable():
def check_objects_being_equally_in_main_init():
"""Check if an object is in the main __init__ if its counterpart in PyTorch is."""
"""
Check if a (TensorFlow or Flax) object is in the main __init__ iif its counterpart in PyTorch is.
"""
attrs = dir(transformers)
failures = []
for attr in attrs:
obj = getattr(transformers, attr)
if hasattr(obj, "__module__"):
module_path = obj.__module__
if "models.deprecated" in module_path:
continue
module_name = module_path.split(".")[-1]
module_dir = ".".join(module_path.split(".")[:-1])
if (
module_name.startswith("modeling_")
and not module_name.startswith("modeling_tf_")
and not module_name.startswith("modeling_flax_")
):
parent_module = sys.modules[module_dir]
if not hasattr(obj, "__module__") or "models.deprecated" in obj.__module__:
continue
frameworks = []
if is_tf_available():
frameworks.append("TF")
if is_flax_available():
frameworks.append("Flax")
module_path = obj.__module__
module_name = module_path.split(".")[-1]
module_dir = ".".join(module_path.split(".")[:-1])
if (
module_name.startswith("modeling_")
and not module_name.startswith("modeling_tf_")
and not module_name.startswith("modeling_flax_")
):
parent_module = sys.modules[module_dir]
for framework in frameworks:
other_module_path = module_path.replace("modeling_", f"modeling_{framework.lower()}_")
if os.path.isfile("src/" + other_module_path.replace(".", "/") + ".py"):
other_module_name = module_name.replace("modeling_", f"modeling_{framework.lower()}_")
other_module = getattr(parent_module, other_module_name)
if hasattr(other_module, f"{framework}{attr}"):
if not hasattr(transformers, f"{framework}{attr}"):
if f"{framework}{attr}" not in OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK:
failures.append(f"{framework}{attr}")
if hasattr(other_module, f"{framework}_{attr}"):
if not hasattr(transformers, f"{framework}_{attr}"):
if f"{framework}_{attr}" not in OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK:
failures.append(f"{framework}_{attr}")
frameworks = []
if is_tf_available():
frameworks.append("TF")
if is_flax_available():
frameworks.append("Flax")
for framework in frameworks:
other_module_path = module_path.replace("modeling_", f"modeling_{framework.lower()}_")
if os.path.isfile("src/" + other_module_path.replace(".", "/") + ".py"):
other_module_name = module_name.replace("modeling_", f"modeling_{framework.lower()}_")
other_module = getattr(parent_module, other_module_name)
if hasattr(other_module, f"{framework}{attr}"):
if not hasattr(transformers, f"{framework}{attr}"):
if f"{framework}{attr}" not in OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK:
failures.append(f"{framework}{attr}")
if hasattr(other_module, f"{framework}_{attr}"):
if not hasattr(transformers, f"{framework}_{attr}"):
if f"{framework}_{attr}" not in OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK:
failures.append(f"{framework}_{attr}")
if len(failures) > 0:
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
@@ -815,8 +809,16 @@ def check_objects_being_equally_in_main_init():
_re_decorator = re.compile(r"^\s*@(\S+)\s+$")
def check_decorator_order(filename):
"""Check that in the test file `filename` the slow decorator is always last."""
def check_decorator_order(filename: str) -> List[int]:
"""
Check that in a given test file, the slow decorator is always last.
Args:
filename (`str`): The path to a test file to check.
Returns:
`List[int]`: The list of failures as a list of indices where there are problems.
"""
with open(filename, "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
decorator_before = None
@@ -849,8 +851,13 @@ def check_all_decorator_order():
)
def find_all_documented_objects():
"""Parse the content of all doc files to detect which classes and functions it documents"""
def find_all_documented_objects() -> List[str]:
"""
Parse the content of all doc files to detect which classes and functions it documents.
Returns:
`List[str]`: The list of all object names being documented.
"""
documented_obj = []
for doc_file in Path(PATH_TO_DOC).glob("**/*.rst"):
with open(doc_file, "r", encoding="utf-8", newline="\n") as f:
@@ -959,8 +966,8 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [
]
def ignore_undocumented(name):
"""Rules to determine if `name` should be undocumented."""
def ignore_undocumented(name: str) -> bool:
"""Rules to determine if `name` should be undocumented (returns `True` if it should not be documented)."""
# NOT DOCUMENTED ON PURPOSE.
# Constants uppercase are not documented.
if name.isupper():
@@ -1047,7 +1054,7 @@ _re_double_backquotes = re.compile(r"(^|[^`])``([^`]+)``([^`]|$)")
_re_rst_example = re.compile(r"^\s*Example.*::\s*$", flags=re.MULTILINE)
def is_rst_docstring(docstring):
def is_rst_docstring(docstring: str) -> True:
"""
Returns `True` if `docstring` is written in rst.
"""
@@ -1061,7 +1068,7 @@ def is_rst_docstring(docstring):
def check_docstrings_are_in_md():
"""Check all docstrings are in md"""
"""Check all docstrings are written in md and nor rst."""
files_with_rst = []
for file in Path(PATH_TO_TRANSFORMERS).glob("**/*.py"):
with open(file, encoding="utf-8") as f:
@@ -1084,6 +1091,9 @@ def check_docstrings_are_in_md():
def check_deprecated_constant_is_up_to_date():
"""
Check if the constant `DEPRECATED_MODELS` in `models/auto/configuration_auto.py` is up to date.
"""
deprecated_folder = os.path.join(PATH_TO_TRANSFORMERS, "models", "deprecated")
deprecated_models = [m for m in os.listdir(deprecated_folder) if not m.startswith("_")]