Check and update model list in index.rst automatically (#7527)
* Check and update model list in index.rst automatically * Check and update model list in index.rst automatically * Adapt template
This commit is contained in:
@@ -23,6 +23,7 @@ import tempfile
|
||||
# All paths are set with the intent you should run this script from the root of the repo with the command
|
||||
# python utils/check_copies.py
|
||||
TRANSFORMERS_PATH = "src/transformers"
|
||||
PATH_TO_DOCS = "docs/source"
|
||||
|
||||
|
||||
def find_code_in_transformers(object_name):
|
||||
@@ -166,6 +167,113 @@ def check_copies(overwrite: bool = False):
|
||||
+ diff
|
||||
+ "\nRun `make fix-copies` or `python utils/check_copies --fix_and_overwrite` to fix them."
|
||||
)
|
||||
check_model_list_copy(overwrite=overwrite)
|
||||
|
||||
|
||||
def get_model_list():
    """
    Extract the model list from the README and return it as a single string.

    Reads `README.md` relative to the current working directory, locates the numbered list that sits between the
    introduction sentence (start prompt) and the "Want to contribute a new model?" entry (end prompt), and
    re-assembles it so that every list entry occupies exactly one line: continuation lines are stripped of their
    leading whitespace and appended to the entry they belong to, and blank lines are dropped.

    Returns:
        `str`: The concatenation of all entries, each one starting with `1.` and kept with the line ending it had
        in the README.

    Raises:
        `ValueError`: If the start prompt or the end prompt cannot be found in the README.
    """
    # If the introduction or the conclusion of the list change, the prompts may need to be updated.
    _start_prompt = "🤗 Transformers currently provides the following architectures"
    _end_prompt = "1. Want to contribute a new model?"
    with open("README.md", "r", encoding="utf-8") as f:
        lines = f.readlines()

    # Find the start of the list; fail loudly instead of running off the end of the file.
    start_index = next((i for i, line in enumerate(lines) if line.startswith(_start_prompt)), None)
    if start_index is None:
        raise ValueError(
            f"Could not find the start of the model list in README.md: no line starts with {_start_prompt!r}."
        )
    start_index += 1

    result = []
    current_line = ""
    for line in lines[start_index:]:
        # The end prompt itself starts with "1.", so it has to be tested before the entry check below.
        if line.startswith(_end_prompt):
            break
        if line.startswith("1."):
            # A new entry begins: flush the one accumulated so far.
            if len(current_line) > 1:
                result.append(current_line)
            current_line = line
        elif len(line) > 1:
            # Continuation of the current entry: replace its trailing newline by a space and glue the text on.
            current_line = f"{current_line[:-1]} {line.lstrip()}"
    else:
        # The loop ran to the end of the file without ever meeting the end prompt.
        raise ValueError(
            f"Could not find the end of the model list in README.md: no line starts with {_end_prompt!r}."
        )
    if len(current_line) > 1:
        result.append(current_line)

    return "".join(result)
|
||||
|
||||
|
||||
def split_long_line_with_indent(line, max_per_line, indent):
    """
    Greedily wrap `line` on single spaces so that rows stay within `max_per_line` characters.

    Words are appended to the current row for as long as the row (including its leading padding) does not exceed
    `max_per_line`; otherwise the row is closed and the word opens a new row prefixed with `indent` spaces. A
    word is never broken, so a single word longer than the limit still ends up on a row of its own.

    Returns the rows joined with newlines.
    """
    first_word, *other_words = line.split(" ")
    padding = " " * indent

    rows = []
    pending = first_word
    for token in other_words:
        candidate = f"{pending} {token}"
        if len(candidate) <= max_per_line:
            # Still fits: keep growing the current row.
            pending = candidate
            continue
        # Too long: close the current row and start an indented one with this word.
        rows.append(pending)
        pending = padding + token
    rows.append(pending)
    return "\n".join(rows)
|
||||
|
||||
|
||||
def convert_to_rst(model_list, max_per_line=None):
    """
    Turn the markdown `model_list` into its rst counterpart.

    Markdown links, bold or not, become rst links of the form `description <link>`__. Each line that starts with
    a number followed by a dot is renumbered with its 1-based position among the lines of `model_list`. When
    `max_per_line` is given, lines longer than that are wrapped with `split_long_line_with_indent`, using the
    width of the "N. " prefix as indentation for the continuation rows.
    """
    rst_link = r"`\1 <\2>`__"

    # Bold links **[description](link)** go first so the surrounding asterisks are consumed with them.
    text = re.sub(r"\*\*\[([^\]]*)\]\(([^\)]*)\)\*\*", rst_link, model_list)
    # Then the plain links [description](link).
    text = re.sub(r"\[([^\]]*)\]\(([^\)]*)\)", rst_link, text)

    converted = []
    for number, entry in enumerate(text.split("\n"), start=1):
        # Replace the markdown numbering by the position of the line.
        entry = re.sub(r"^\s*(\d+)\.", f"{number}.", entry)

        too_long = max_per_line is not None and len(entry) > max_per_line
        if too_long:
            # Continuation rows are aligned with the text that follows the "N. " prefix.
            bullet = re.search(r"^(\s*\d+\.\s+)\S", entry)
            hanging_indent = 0 if bullet is None else len(bullet.group(1))
            entry = split_long_line_with_indent(entry, max_per_line, hanging_indent)

        converted.append(entry)
    return "\n".join(converted)
|
||||
|
||||
|
||||
def check_model_list_copy(overwrite=False, max_per_line=119):
    """
    Check the model lists in the README and index.rst are consistent and maybe `overwrite`.

    The list stored in `index.rst` (the non-blank lines found between the start prompt and the `.. toctree::`
    directive) is compared with the list returned by `get_model_list()` once converted with `convert_to_rst`.

    Args:
        overwrite (`bool`, defaults to `False`):
            When the two lists differ, rewrite `index.rst` with the converted list instead of raising.
        max_per_line (`int`, defaults to 119):
            Maximum line length forwarded to `convert_to_rst`.

    Raises:
        `ValueError`: If one of the two prompts delimiting the list cannot be found in `index.rst`, or if the
        lists differ and `overwrite` is `False`.
    """
    # NOTE: the leading whitespace of the start prompt is significant, it has to match the line in index.rst.
    _start_prompt = " This list is updated automatically from the README"
    _end_prompt = ".. toctree::"
    index_path = os.path.join(PATH_TO_DOCS, "index.rst")
    with open(index_path, "r", encoding="utf-8") as f:
        lines = f.readlines()

    # Find the start of the list; fail loudly instead of running off the end of the file.
    start_index = 0
    while start_index < len(lines) and not lines[start_index].startswith(_start_prompt):
        start_index += 1
    if start_index == len(lines):
        raise ValueError(
            f"Could not find the start of the model list in {index_path}: no line starts with {_start_prompt!r}."
        )
    start_index += 1

    # Find the directive that closes the list.
    end_index = start_index
    while end_index < len(lines) and not lines[end_index].startswith(_end_prompt):
        end_index += 1
    if end_index == len(lines):
        raise ValueError(
            f"Could not find the end of the model list in {index_path}: no line starts with {_end_prompt!r}."
        )
    end_index -= 1

    # Skip the blank lines surrounding the list. Both loops are bounded now that the two prompts are known to
    # exist: the prompt lines themselves are not blank.
    while len(lines[start_index]) <= 1:
        start_index += 1
    while len(lines[end_index]) <= 1:
        end_index -= 1
    # Make `end_index` exclusive so it can be used for slicing.
    end_index += 1

    rst_list = "".join(lines[start_index:end_index])
    md_list = get_model_list()
    converted_list = convert_to_rst(md_list, max_per_line=max_per_line)

    if converted_list != rst_list:
        if overwrite:
            with open(index_path, "w", encoding="utf-8") as f:
                f.writelines(lines[:start_index] + [converted_list] + lines[end_index:])
        else:
            raise ValueError(
                "The model list in the README changed and the list in `index.rst` has not been updated. Run `make fix-copies` to fix this."
            )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user