Big model table (#8774)

* First draft

* Styling

* With all changes staged

* Update docs/source/index.rst

Co-authored-by: Julien Chaumond <chaumond@gmail.com>

* Styling

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
Sylvain Gugger
2020-11-25 12:02:15 -05:00
committed by GitHub
parent 90d5ab3bfe
commit 4821ea5aeb
8 changed files with 257 additions and 19 deletions

View File

@@ -15,6 +15,7 @@
import argparse
import glob
import importlib
import os
import re
import tempfile
@@ -250,20 +251,21 @@ def convert_to_rst(model_list, max_per_line=None):
return "\n".join(result)
def check_model_list_copy(overwrite=False, max_per_line=119):
""" Check the model lists in the README and index.rst are consistent and maybe `overwrite`. """
_start_prompt = " This list is updated automatically from the README"
_end_prompt = ".. toctree::"
with open(os.path.join(PATH_TO_DOCS, "index.rst"), "r", encoding="utf-8", newline="\n") as f:
def _find_text_in_file(filename, start_prompt, end_prompt):
"""
Find the text in `filename` between a line beginning with `start_prompt` and before `end_prompt`, removing empty
lines.
"""
with open(filename, "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
# Find the start of the list.
# Find the start prompt.
start_index = 0
while not lines[start_index].startswith(_start_prompt):
while not lines[start_index].startswith(start_prompt):
start_index += 1
start_index += 1
end_index = start_index
while not lines[end_index].startswith(_end_prompt):
while not lines[end_index].startswith(end_prompt):
end_index += 1
end_index -= 1
@@ -272,8 +274,16 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
while len(lines[end_index]) <= 1:
end_index -= 1
end_index += 1
return "".join(lines[start_index:end_index]), start_index, end_index, lines
rst_list = "".join(lines[start_index:end_index])
def check_model_list_copy(overwrite=False, max_per_line=119):
""" Check the model lists in the README and index.rst are consistent and maybe `overwrite`. """
rst_list, start_index, end_index, lines = _find_text_in_file(
filename=os.path.join(PATH_TO_DOCS, "index.rst"),
start_prompt=" This list is updated automatically from the README",
end_prompt="The table below represents the current support",
)
md_list = get_model_list()
converted_list = convert_to_rst(md_list, max_per_line=max_per_line)
@@ -283,7 +293,116 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
f.writelines(lines[:start_index] + [converted_list] + lines[end_index:])
else:
raise ValueError(
"The model list in the README changed and the list in `index.rst` has not been updated. Run `make fix-copies` to fix this."
"The model list in the README changed and the list in `index.rst` has not been updated. Run "
"`make fix-copies` to fix this."
)
def _center_text(text, width):
text_length = 2 if text == "" or text == "" else len(text)
left_indent = (width - text_length) // 2
right_indent = width - text_length - left_indent
return " " * left_indent + text + " " * right_indent
def get_model_table_from_auto_modules():
"""Generates an up-to-date model table from the content of the auto modules."""
# This is to make sure the transformers module imported is the one in the repo.
spec = importlib.util.spec_from_file_location(
"transformers",
os.path.join(TRANSFORMERS_PATH, "__init__.py"),
submodule_search_locations=[TRANSFORMERS_PATH],
)
transformers = spec.loader.load_module()
# Dictionary model names to config.
model_name_to_config = {
name: transformers.CONFIG_MAPPING[code] for code, name in transformers.MODEL_NAMES_MAPPING.items()
}
# All tokenizer tuples.
tokenizers = {
name: transformers.TOKENIZER_MAPPING[config]
for name, config in model_name_to_config.items()
if config in transformers.TOKENIZER_MAPPING
}
# Model names that a slow/fast tokenizer.
has_slow_tokenizers = [name for name, tok in tokenizers.items() if tok[0] is not None]
has_fast_tokenizers = [name for name, tok in tokenizers.items() if tok[1] is not None]
# Model names that have a PyTorch implementation.
has_pt_model = [name for name, config in model_name_to_config.items() if config in transformers.MODEL_MAPPING]
# Some of the GenerationModel don't have a base model.
has_pt_model.extend(
[
name
for name, config in model_name_to_config.items()
if config in transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
]
)
# Special exception for RAG
has_pt_model.append("RAG")
# Model names that have a TensorFlow implementation.
has_tf_model = [name for name, config in model_name_to_config.items() if config in transformers.TF_MODEL_MAPPING]
# Some of the GenerationModel don't have a base model.
has_tf_model.extend(
[
name
for name, config in model_name_to_config.items()
if config in transformers.TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
]
)
# Model names that have a Flax implementation.
has_flax_model = [
name for name, config in model_name_to_config.items() if config in transformers.FLAX_MODEL_MAPPING
]
# Let's build that table!
model_names = list(model_name_to_config.keys())
model_names.sort()
columns = ["Model", "Tokenizer slow", "Tokenizer fast", "PyTorch support", "TensorFlow support", "Flax Support"]
# We'll need widths to properly display everything in the center (+2 is to leave one extra space on each side).
widths = [len(c) + 2 for c in columns]
widths[0] = max([len(name) for name in model_names]) + 2
# Rst table per se
table = ".. rst-class:: center-aligned-table\n\n"
table += "+" + "+".join(["-" * w for w in widths]) + "+\n"
table += "|" + "|".join([_center_text(c, w) for c, w in zip(columns, widths)]) + "|\n"
table += "+" + "+".join(["=" * w for w in widths]) + "+\n"
check = {True: "", False: ""}
for name in model_names:
line = [
name,
check[name in has_slow_tokenizers],
check[name in has_fast_tokenizers],
check[name in has_pt_model],
check[name in has_tf_model],
check[name in has_flax_model],
]
table += "|" + "|".join([_center_text(l, w) for l, w in zip(line, widths)]) + "|\n"
table += "+" + "+".join(["-" * w for w in widths]) + "+\n"
return table
def check_model_table(overwrite=False):
""" Check the model table in the index.rst is consistent with the state of the lib and maybe `overwrite`. """
current_table, start_index, end_index, lines = _find_text_in_file(
filename=os.path.join(PATH_TO_DOCS, "index.rst"),
start_prompt=" This table is updated automatically from the auto module",
end_prompt=".. toctree::",
)
new_table = get_model_table_from_auto_modules()
if current_table != new_table:
if overwrite:
with open(os.path.join(PATH_TO_DOCS, "index.rst"), "w", encoding="utf-8", newline="\n") as f:
f.writelines(lines[:start_index] + [new_table] + lines[end_index:])
else:
raise ValueError(
"The model table in the `index.rst` has not been updated. Run `make fix-copies` to fix this."
)
@@ -293,3 +412,4 @@ if __name__ == "__main__":
args = parser.parse_args()
check_copies(args.fix_and_overwrite)
check_model_table(args.fix_and_overwrite)

View File

@@ -126,6 +126,7 @@ def get_model_modules():
"modeling_outputs",
"modeling_retribert",
"modeling_utils",
"modeling_flax_auto",
"modeling_flax_utils",
"modeling_transfo_xl_utilities",
"modeling_tf_auto",