Big model table (#8774)
* First draft * Styling * With all changes staged * Update docs/source/index.rst Co-authored-by: Julien Chaumond <chaumond@gmail.com> * Styling Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import importlib
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
@@ -250,20 +251,21 @@ def convert_to_rst(model_list, max_per_line=None):
|
||||
return "\n".join(result)
|
||||
|
||||
|
||||
def check_model_list_copy(overwrite=False, max_per_line=119):
|
||||
""" Check the model lists in the README and index.rst are consistent and maybe `overwrite`. """
|
||||
_start_prompt = " This list is updated automatically from the README"
|
||||
_end_prompt = ".. toctree::"
|
||||
with open(os.path.join(PATH_TO_DOCS, "index.rst"), "r", encoding="utf-8", newline="\n") as f:
|
||||
def _find_text_in_file(filename, start_prompt, end_prompt):
|
||||
"""
|
||||
Find the text in `filename` between a line beginning with `start_prompt` and before `end_prompt`, removing empty
|
||||
lines.
|
||||
"""
|
||||
with open(filename, "r", encoding="utf-8", newline="\n") as f:
|
||||
lines = f.readlines()
|
||||
# Find the start of the list.
|
||||
# Find the start prompt.
|
||||
start_index = 0
|
||||
while not lines[start_index].startswith(_start_prompt):
|
||||
while not lines[start_index].startswith(start_prompt):
|
||||
start_index += 1
|
||||
start_index += 1
|
||||
|
||||
end_index = start_index
|
||||
while not lines[end_index].startswith(_end_prompt):
|
||||
while not lines[end_index].startswith(end_prompt):
|
||||
end_index += 1
|
||||
end_index -= 1
|
||||
|
||||
@@ -272,8 +274,16 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
|
||||
while len(lines[end_index]) <= 1:
|
||||
end_index -= 1
|
||||
end_index += 1
|
||||
return "".join(lines[start_index:end_index]), start_index, end_index, lines
|
||||
|
||||
rst_list = "".join(lines[start_index:end_index])
|
||||
|
||||
def check_model_list_copy(overwrite=False, max_per_line=119):
|
||||
""" Check the model lists in the README and index.rst are consistent and maybe `overwrite`. """
|
||||
rst_list, start_index, end_index, lines = _find_text_in_file(
|
||||
filename=os.path.join(PATH_TO_DOCS, "index.rst"),
|
||||
start_prompt=" This list is updated automatically from the README",
|
||||
end_prompt="The table below represents the current support",
|
||||
)
|
||||
md_list = get_model_list()
|
||||
converted_list = convert_to_rst(md_list, max_per_line=max_per_line)
|
||||
|
||||
@@ -283,7 +293,116 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
|
||||
f.writelines(lines[:start_index] + [converted_list] + lines[end_index:])
|
||||
else:
|
||||
raise ValueError(
|
||||
"The model list in the README changed and the list in `index.rst` has not been updated. Run `make fix-copies` to fix this."
|
||||
"The model list in the README changed and the list in `index.rst` has not been updated. Run "
|
||||
"`make fix-copies` to fix this."
|
||||
)
|
||||
|
||||
|
||||
def _center_text(text, width):
|
||||
text_length = 2 if text == "✅" or text == "❌" else len(text)
|
||||
left_indent = (width - text_length) // 2
|
||||
right_indent = width - text_length - left_indent
|
||||
return " " * left_indent + text + " " * right_indent
|
||||
|
||||
|
||||
def get_model_table_from_auto_modules():
    """
    Build the rst support table (tokenizers and frameworks, one row per model) from the auto-module mappings.

    Returns:
        The table as a single string: an `.. rst-class:: center-aligned-table` directive followed by a grid table
        with a header row and one row per model name, sorted by name.
    """
    # Load `transformers` from the files under TRANSFORMERS_PATH so the repo version is the one inspected.
    # NOTE(review): `Loader.load_module()` is documented as deprecated by importlib; kept unchanged here.
    spec = importlib.util.spec_from_file_location(
        "transformers",
        os.path.join(TRANSFORMERS_PATH, "__init__.py"),
        submodule_search_locations=[TRANSFORMERS_PATH],
    )
    transformers = spec.loader.load_module()

    # Displayed model name -> config class.
    configs_by_name = {}
    for code, name in transformers.MODEL_NAMES_MAPPING.items():
        configs_by_name[name] = transformers.CONFIG_MAPPING[code]

    def _names_supported_by(mapping):
        # Names of the models whose config class is a key of `mapping`.
        return {name for name, config in configs_by_name.items() if config in mapping}

    # Model names that have a slow (index 0) / fast (index 1) tokenizer.
    slow_tokenizer_names = set()
    fast_tokenizer_names = set()
    for name, config in configs_by_name.items():
        if config not in transformers.TOKENIZER_MAPPING:
            continue
        tokenizer_pair = transformers.TOKENIZER_MAPPING[config]
        if tokenizer_pair[0] is not None:
            slow_tokenizer_names.add(name)
        if tokenizer_pair[1] is not None:
            fast_tokenizer_names.add(name)

    # PyTorch support: base models, plus the seq-to-seq generation models that have no base model.
    pt_names = _names_supported_by(transformers.MODEL_MAPPING)
    pt_names |= _names_supported_by(transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING)
    # Special exception for RAG.
    pt_names.add("RAG")

    # TensorFlow support, with the same base-model caveat as PyTorch.
    tf_names = _names_supported_by(transformers.TF_MODEL_MAPPING)
    tf_names |= _names_supported_by(transformers.TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING)

    # Flax support.
    flax_names = _names_supported_by(transformers.FLAX_MODEL_MAPPING)

    model_names = sorted(configs_by_name)
    columns = ["Model", "Tokenizer slow", "Tokenizer fast", "PyTorch support", "TensorFlow support", "Flax Support"]
    # Each column is as wide as its header plus one space on each side; the first column is sized on the longest
    # model name instead.
    widths = [len(header) + 2 for header in columns]
    widths[0] = max(len(name) for name in model_names) + 2

    def _rule(char):
        # Horizontal border of the grid table drawn with `char`.
        return "+" + "+".join(char * w for w in widths) + "+\n"

    def _row(cells):
        # One table row with every cell centered in its column.
        return "|" + "|".join(_center_text(cell, w) for cell, w in zip(cells, widths)) + "|\n"

    def _mark(is_supported):
        return "✅" if is_supported else "❌"

    thin_rule = _rule("-")
    pieces = [".. rst-class:: center-aligned-table\n\n", thin_rule, _row(columns), _rule("=")]
    for name in model_names:
        cells = [
            name,
            _mark(name in slow_tokenizer_names),
            _mark(name in fast_tokenizer_names),
            _mark(name in pt_names),
            _mark(name in tf_names),
            _mark(name in flax_names),
        ]
        pieces.append(_row(cells))
        pieces.append(thin_rule)
    return "".join(pieces)
|
||||
|
||||
|
||||
def check_model_table(overwrite=False):
    """
    Compare the model table stored in `index.rst` with the one generated from the auto modules.

    Args:
        overwrite: when the two tables differ, rewrite `index.rst` with the generated table instead of raising.

    Raises:
        ValueError: if the tables differ and `overwrite` is False.
    """
    index_path = os.path.join(PATH_TO_DOCS, "index.rst")
    # The table sits between the line starting with the first prompt and the line starting with the toctree marker.
    current_table, start_index, end_index, lines = _find_text_in_file(
        filename=index_path,
        start_prompt=" This table is updated automatically from the auto module",
        end_prompt=".. toctree::",
    )
    new_table = get_model_table_from_auto_modules()

    if current_table == new_table:
        return
    if not overwrite:
        raise ValueError(
            "The model table in the `index.rst` has not been updated. Run `make fix-copies` to fix this."
        )
    # Splice the fresh table in place of the stale one and keep every other line of the file.
    with open(index_path, "w", encoding="utf-8", newline="\n") as f:
        f.writelines(lines[:start_index] + [new_table] + lines[end_index:])
|
||||
|
||||
|
||||
@@ -293,3 +412,4 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
check_copies(args.fix_and_overwrite)
|
||||
check_model_table(args.fix_and_overwrite)
|
||||
|
||||
@@ -126,6 +126,7 @@ def get_model_modules():
|
||||
"modeling_outputs",
|
||||
"modeling_retribert",
|
||||
"modeling_utils",
|
||||
"modeling_flax_auto",
|
||||
"modeling_flax_utils",
|
||||
"modeling_transfo_xl_utilities",
|
||||
"modeling_tf_auto",
|
||||
|
||||
Reference in New Issue
Block a user