make style (#11442)

This commit is contained in:
Patrick von Platen
2021-04-26 13:50:34 +02:00
committed by GitHub
parent 04ab2ca639
commit 32dbb2d954
105 changed files with 202 additions and 202 deletions

View File

@@ -119,7 +119,7 @@ transformers = spec.loader.load_module()
# If some modeling modules should be ignored for all checks, they should be added in the nested list
# _ignore_modules of this function.
def get_model_modules():
""" Get the model modules inside the transformers library. """
"""Get the model modules inside the transformers library."""
_ignore_modules = [
"modeling_auto",
"modeling_encoder_decoder",
@@ -151,7 +151,7 @@ def get_model_modules():
def get_models(module):
""" Get the objects in module that are models."""
"""Get the objects in module that are models."""
models = []
model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel)
for attr_name in dir(module):
@@ -166,7 +166,7 @@ def get_models(module):
# If some test_modeling files should be ignored when checking models are all tested, they should be added in the
# nested list _ignore_files of this function.
def get_model_test_files():
""" Get the model test files."""
"""Get the model test files."""
_ignore_files = [
"test_modeling_common",
"test_modeling_encoder_decoder",
@@ -187,7 +187,7 @@ def get_model_test_files():
# This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the tester class
# for the all_model_classes variable.
def find_tested_models(test_file):
""" Parse the content of test_file to detect what's in all_model_classes"""
"""Parse the content of test_file to detect what's in all_model_classes"""
# This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the class
with open(os.path.join(PATH_TO_TESTS, test_file), "r", encoding="utf-8", newline="\n") as f:
content = f.read()
@@ -205,7 +205,7 @@ def find_tested_models(test_file):
def check_models_are_tested(module, test_file):
""" Check models defined in module are tested in test_file."""
"""Check models defined in module are tested in test_file."""
defined_models = get_models(module)
tested_models = find_tested_models(test_file)
if tested_models is None:
@@ -229,7 +229,7 @@ def check_models_are_tested(module, test_file):
def check_all_models_are_tested():
""" Check all models are properly tested."""
"""Check all models are properly tested."""
modules = get_model_modules()
test_files = get_model_test_files()
failures = []
@@ -245,7 +245,7 @@ def check_all_models_are_tested():
def get_all_auto_configured_models():
""" Return the list of all models in at least one auto class."""
"""Return the list of all models in at least one auto class."""
result = set() # To avoid duplicates we concatenate all model classes in a set.
for attr_name in dir(transformers.models.auto.modeling_auto):
if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING"):
@@ -271,7 +271,7 @@ def ignore_unautoclassed(model_name):
def check_models_are_auto_configured(module, all_auto_models):
""" Check models defined in module are each in an auto class."""
"""Check models defined in module are each in an auto class."""
defined_models = get_models(module)
failures = []
for model_name, _ in defined_models:
@@ -285,7 +285,7 @@ def check_models_are_auto_configured(module, all_auto_models):
def check_all_models_are_auto_configured():
""" Check all models are each in an auto class."""
"""Check all models are each in an auto class."""
modules = get_model_modules()
all_auto_models = get_all_auto_configured_models()
failures = []
@@ -301,7 +301,7 @@ _re_decorator = re.compile(r"^\s*@(\S+)\s+$")
def check_decorator_order(filename):
""" Check that in the test file `filename` the slow decorator is always last."""
"""Check that in the test file `filename` the slow decorator is always last."""
with open(filename, "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
decorator_before = None
@@ -319,7 +319,7 @@ def check_decorator_order(filename):
def check_all_decorator_order():
""" Check that in all test files, the slow decorator is always last."""
"""Check that in all test files, the slow decorator is always last."""
errors = []
for fname in os.listdir(PATH_TO_TESTS):
if fname.endswith(".py"):
@@ -334,7 +334,7 @@ def check_all_decorator_order():
def find_all_documented_objects():
""" Parse the content of all doc files to detect which classes and functions it documents"""
"""Parse the content of all doc files to detect which classes and functions it documents"""
documented_obj = []
for doc_file in Path(PATH_TO_DOC).glob("**/*.rst"):
with open(doc_file, "r", encoding="utf-8", newline="\n") as f:
@@ -454,7 +454,7 @@ def ignore_undocumented(name):
def check_all_objects_are_documented():
""" Check all models are properly documented."""
"""Check all models are properly documented."""
documented_objs = find_all_documented_objects()
modules = transformers._modules
objects = [c for c in dir(transformers) if c not in modules and not c.startswith("_")]
@@ -467,7 +467,7 @@ def check_all_objects_are_documented():
def check_repo_quality():
""" Check all models are properly tested and documented."""
"""Check all models are properly tested and documented."""
print("Checking all models are properly tested.")
check_all_decorator_order()
check_all_models_are_tested()