make style (#11442)
This commit is contained in:
committed by
GitHub
parent
04ab2ca639
commit
32dbb2d954
@@ -33,7 +33,7 @@ def _should_continue(line, indent):
|
||||
|
||||
|
||||
def find_code_in_transformers(object_name):
|
||||
""" Find and return the code source code of `object_name`."""
|
||||
"""Find and return the code source code of `object_name`."""
|
||||
parts = object_name.split(".")
|
||||
i = 0
|
||||
|
||||
@@ -193,7 +193,7 @@ def check_copies(overwrite: bool = False):
|
||||
|
||||
|
||||
def get_model_list():
|
||||
""" Extracts the model list from the README. """
|
||||
"""Extracts the model list from the README."""
|
||||
# If the introduction or the conclusion of the list change, the prompts may need to be updated.
|
||||
_start_prompt = "🤗 Transformers currently provides the following architectures"
|
||||
_end_prompt = "1. Want to contribute a new model?"
|
||||
@@ -224,7 +224,7 @@ def get_model_list():
|
||||
|
||||
|
||||
def split_long_line_with_indent(line, max_per_line, indent):
|
||||
""" Split the `line` so that it doesn't go over `max_per_line` and adds `indent` to new lines. """
|
||||
"""Split the `line` so that it doesn't go over `max_per_line` and adds `indent` to new lines."""
|
||||
words = line.split(" ")
|
||||
lines = []
|
||||
current_line = words[0]
|
||||
@@ -239,7 +239,7 @@ def split_long_line_with_indent(line, max_per_line, indent):
|
||||
|
||||
|
||||
def convert_to_rst(model_list, max_per_line=None):
|
||||
""" Convert `model_list` to rst format. """
|
||||
"""Convert `model_list` to rst format."""
|
||||
# Convert **[description](link)** to `description <link>`__
|
||||
def _rep_link(match):
|
||||
title, link = match.groups()
|
||||
@@ -298,7 +298,7 @@ def _find_text_in_file(filename, start_prompt, end_prompt):
|
||||
|
||||
|
||||
def check_model_list_copy(overwrite=False, max_per_line=119):
|
||||
""" Check the model lists in the README and index.rst are consistent and maybe `overwrite`. """
|
||||
"""Check the model lists in the README and index.rst are consistent and maybe `overwrite`."""
|
||||
rst_list, start_index, end_index, lines = _find_text_in_file(
|
||||
filename=os.path.join(PATH_TO_DOCS, "index.rst"),
|
||||
start_prompt=" This list is updated automatically from the README",
|
||||
|
||||
@@ -65,7 +65,7 @@ def find_backend(line):
|
||||
|
||||
|
||||
def read_init():
|
||||
""" Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects. """
|
||||
"""Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects."""
|
||||
with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
@@ -101,7 +101,7 @@ def read_init():
|
||||
|
||||
|
||||
def create_dummy_object(name, backend_name):
|
||||
""" Create the code for the dummy object corresponding to `name`."""
|
||||
"""Create the code for the dummy object corresponding to `name`."""
|
||||
_pretrained = [
|
||||
"Config" "ForCausalLM",
|
||||
"ForConditionalGeneration",
|
||||
@@ -130,7 +130,7 @@ def create_dummy_object(name, backend_name):
|
||||
|
||||
|
||||
def create_dummy_files():
|
||||
""" Create the content of the dummy files. """
|
||||
"""Create the content of the dummy files."""
|
||||
backend_specific_objects = read_init()
|
||||
# For special correspondence backend to module name as used in the function requires_modulename
|
||||
dummy_files = {}
|
||||
@@ -146,7 +146,7 @@ def create_dummy_files():
|
||||
|
||||
|
||||
def check_dummies(overwrite=False):
|
||||
""" Check if the dummy files are up to date and maybe `overwrite` with the right content. """
|
||||
"""Check if the dummy files are up to date and maybe `overwrite` with the right content."""
|
||||
dummy_files = create_dummy_files()
|
||||
# For special correspondence backend to shortcut as used in utils/dummy_xxx_objects.py
|
||||
short_names = {"torch": "pt"}
|
||||
|
||||
@@ -119,7 +119,7 @@ transformers = spec.loader.load_module()
|
||||
# If some modeling modules should be ignored for all checks, they should be added in the nested list
|
||||
# _ignore_modules of this function.
|
||||
def get_model_modules():
|
||||
""" Get the model modules inside the transformers library. """
|
||||
"""Get the model modules inside the transformers library."""
|
||||
_ignore_modules = [
|
||||
"modeling_auto",
|
||||
"modeling_encoder_decoder",
|
||||
@@ -151,7 +151,7 @@ def get_model_modules():
|
||||
|
||||
|
||||
def get_models(module):
|
||||
""" Get the objects in module that are models."""
|
||||
"""Get the objects in module that are models."""
|
||||
models = []
|
||||
model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel)
|
||||
for attr_name in dir(module):
|
||||
@@ -166,7 +166,7 @@ def get_models(module):
|
||||
# If some test_modeling files should be ignored when checking models are all tested, they should be added in the
|
||||
# nested list _ignore_files of this function.
|
||||
def get_model_test_files():
|
||||
""" Get the model test files."""
|
||||
"""Get the model test files."""
|
||||
_ignore_files = [
|
||||
"test_modeling_common",
|
||||
"test_modeling_encoder_decoder",
|
||||
@@ -187,7 +187,7 @@ def get_model_test_files():
|
||||
# This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the tester class
|
||||
# for the all_model_classes variable.
|
||||
def find_tested_models(test_file):
|
||||
""" Parse the content of test_file to detect what's in all_model_classes"""
|
||||
"""Parse the content of test_file to detect what's in all_model_classes"""
|
||||
# This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the class
|
||||
with open(os.path.join(PATH_TO_TESTS, test_file), "r", encoding="utf-8", newline="\n") as f:
|
||||
content = f.read()
|
||||
@@ -205,7 +205,7 @@ def find_tested_models(test_file):
|
||||
|
||||
|
||||
def check_models_are_tested(module, test_file):
|
||||
""" Check models defined in module are tested in test_file."""
|
||||
"""Check models defined in module are tested in test_file."""
|
||||
defined_models = get_models(module)
|
||||
tested_models = find_tested_models(test_file)
|
||||
if tested_models is None:
|
||||
@@ -229,7 +229,7 @@ def check_models_are_tested(module, test_file):
|
||||
|
||||
|
||||
def check_all_models_are_tested():
|
||||
""" Check all models are properly tested."""
|
||||
"""Check all models are properly tested."""
|
||||
modules = get_model_modules()
|
||||
test_files = get_model_test_files()
|
||||
failures = []
|
||||
@@ -245,7 +245,7 @@ def check_all_models_are_tested():
|
||||
|
||||
|
||||
def get_all_auto_configured_models():
|
||||
""" Return the list of all models in at least one auto class."""
|
||||
"""Return the list of all models in at least one auto class."""
|
||||
result = set() # To avoid duplicates we concatenate all model classes in a set.
|
||||
for attr_name in dir(transformers.models.auto.modeling_auto):
|
||||
if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING"):
|
||||
@@ -271,7 +271,7 @@ def ignore_unautoclassed(model_name):
|
||||
|
||||
|
||||
def check_models_are_auto_configured(module, all_auto_models):
|
||||
""" Check models defined in module are each in an auto class."""
|
||||
"""Check models defined in module are each in an auto class."""
|
||||
defined_models = get_models(module)
|
||||
failures = []
|
||||
for model_name, _ in defined_models:
|
||||
@@ -285,7 +285,7 @@ def check_models_are_auto_configured(module, all_auto_models):
|
||||
|
||||
|
||||
def check_all_models_are_auto_configured():
|
||||
""" Check all models are each in an auto class."""
|
||||
"""Check all models are each in an auto class."""
|
||||
modules = get_model_modules()
|
||||
all_auto_models = get_all_auto_configured_models()
|
||||
failures = []
|
||||
@@ -301,7 +301,7 @@ _re_decorator = re.compile(r"^\s*@(\S+)\s+$")
|
||||
|
||||
|
||||
def check_decorator_order(filename):
|
||||
""" Check that in the test file `filename` the slow decorator is always last."""
|
||||
"""Check that in the test file `filename` the slow decorator is always last."""
|
||||
with open(filename, "r", encoding="utf-8", newline="\n") as f:
|
||||
lines = f.readlines()
|
||||
decorator_before = None
|
||||
@@ -319,7 +319,7 @@ def check_decorator_order(filename):
|
||||
|
||||
|
||||
def check_all_decorator_order():
|
||||
""" Check that in all test files, the slow decorator is always last."""
|
||||
"""Check that in all test files, the slow decorator is always last."""
|
||||
errors = []
|
||||
for fname in os.listdir(PATH_TO_TESTS):
|
||||
if fname.endswith(".py"):
|
||||
@@ -334,7 +334,7 @@ def check_all_decorator_order():
|
||||
|
||||
|
||||
def find_all_documented_objects():
|
||||
""" Parse the content of all doc files to detect which classes and functions it documents"""
|
||||
"""Parse the content of all doc files to detect which classes and functions it documents"""
|
||||
documented_obj = []
|
||||
for doc_file in Path(PATH_TO_DOC).glob("**/*.rst"):
|
||||
with open(doc_file, "r", encoding="utf-8", newline="\n") as f:
|
||||
@@ -454,7 +454,7 @@ def ignore_undocumented(name):
|
||||
|
||||
|
||||
def check_all_objects_are_documented():
|
||||
""" Check all models are properly documented."""
|
||||
"""Check all models are properly documented."""
|
||||
documented_objs = find_all_documented_objects()
|
||||
modules = transformers._modules
|
||||
objects = [c for c in dir(transformers) if c not in modules and not c.startswith("_")]
|
||||
@@ -467,7 +467,7 @@ def check_all_objects_are_documented():
|
||||
|
||||
|
||||
def check_repo_quality():
|
||||
""" Check all models are properly tested and documented."""
|
||||
"""Check all models are properly tested and documented."""
|
||||
print("Checking all models are properly tested.")
|
||||
check_all_decorator_order()
|
||||
check_all_models_are_tested()
|
||||
|
||||
@@ -159,7 +159,7 @@ def get_model_table_from_auto_modules():
|
||||
|
||||
|
||||
def check_model_table(overwrite=False):
|
||||
""" Check the model table in the index.rst is consistent with the state of the lib and maybe `overwrite`. """
|
||||
"""Check the model table in the index.rst is consistent with the state of the lib and maybe `overwrite`."""
|
||||
current_table, start_index, end_index, lines = _find_text_in_file(
|
||||
filename=os.path.join(PATH_TO_DOCS, "index.rst"),
|
||||
start_prompt=" This table is updated automatically from the auto module",
|
||||
|
||||
@@ -431,7 +431,7 @@ def _add_new_lines_before_doc_special_words(text):
|
||||
|
||||
|
||||
def style_rst_file(doc_file, max_len=119, check_only=False):
|
||||
""" Style one rst file `doc_file` to `max_len`."""
|
||||
"""Style one rst file `doc_file` to `max_len`."""
|
||||
with open(doc_file, "r", encoding="utf-8", newline="\n") as f:
|
||||
doc = f.read()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user