Check table as independent script (#8976)
This commit is contained in:
@@ -14,9 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
import glob
|
||||
import importlib.util
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
@@ -299,134 +297,9 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
|
||||
)
|
||||
|
||||
|
||||
# Add here suffixes that are used to identify models, seperated by |
|
||||
ALLOWED_MODEL_SUFFIXES = "Model|Encoder|Decoder|ForConditionalGeneration"
|
||||
# Regexes that match TF/Flax/PT model names.
|
||||
_re_tf_models = re.compile(r"TF(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
|
||||
_re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
|
||||
# Will match any TF or Flax model too so need to be in an else branch afterthe two previous regexes.
|
||||
_re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
|
||||
|
||||
|
||||
# Thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
def camel_case_split(identifier):
    """
    Split a camelcased `identifier` into words.

    A word boundary is placed between a lowercase letter and the uppercase letter that follows it, and
    before the last capital of a run of capitals that is followed by a lowercase letter, so acronyms stay
    in one piece (`"TFBertModel"` gives `["TF", "Bert", "Model"]`).

    Args:
        identifier (`str`): The name to split.

    Returns:
        `List[str]`: The words in order; joining them gives back `identifier`. Empty for an empty input.
    """
    # The pattern has no capturing group, so `findall` returns the full text of each match. Every match
    # consumes at least one character (`.+?`), so no empty string can end up in the result.
    return re.findall(r".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)", identifier)
|
||||
|
||||
|
||||
def _center_text(text, width):
|
||||
text_length = 2 if text == "✅" or text == "❌" else len(text)
|
||||
left_indent = (width - text_length) // 2
|
||||
right_indent = width - text_length - left_indent
|
||||
return " " * left_indent + text + " " * right_indent
|
||||
|
||||
|
||||
def get_model_table_from_auto_modules():
    """
    Generates an up-to-date model table from the content of the auto modules.

    The `transformers` package is loaded from `TRANSFORMERS_PATH`, then every name it exposes is
    inspected once to flag, for each model prefix (config class name without `Config`), whether a slow
    tokenizer, a fast tokenizer, a PyTorch model, a TensorFlow model and a Flax model exist.

    Returns:
        `str`: The rst source of the table: an `rst-class` directive, a header row, then one row per
        model name (sorted), each cell holding a ✅ or ❌.
    """
    # This is to make sure the transformers module imported is the one in the repo.
    spec = importlib.util.spec_from_file_location(
        "transformers",
        os.path.join(TRANSFORMERS_PATH, "__init__.py"),
        submodule_search_locations=[TRANSFORMERS_PATH],
    )
    # NOTE(review): `Loader.load_module()` is documented as deprecated in importlib; moving to
    # `importlib.util.module_from_spec` + `spec.loader.exec_module` should be considered separately.
    transformers = spec.loader.load_module()

    # Dictionary model names to config.
    model_name_to_config = {
        name: transformers.CONFIG_MAPPING[code] for code, name in transformers.MODEL_NAMES_MAPPING.items()
    }
    model_name_to_prefix = {
        name: config.__name__.replace("Config", "") for name, config in model_name_to_config.items()
    }
    # Built once: the membership test below runs for every attribute and every shortened candidate.
    known_prefixes = set(model_name_to_prefix.values())

    # Dictionaries flagging if each model prefix has a slow/fast tokenizer, backend in PT/TF/Flax.
    slow_tokenizers = collections.defaultdict(bool)
    fast_tokenizers = collections.defaultdict(bool)
    pt_models = collections.defaultdict(bool)
    tf_models = collections.defaultdict(bool)
    flax_models = collections.defaultdict(bool)

    # Order matters: the PT regex also matches TF and Flax names, so it has to be tried last.
    model_regexes = (
        (_re_tf_models, tf_models),
        (_re_flax_models, flax_models),
        (_re_pt_models, pt_models),
    )

    # Let's lookup through all transformers object (once).
    for attr_name in dir(transformers):
        lookup_dict = None
        if attr_name.endswith("Tokenizer"):
            lookup_dict = slow_tokenizers
            attr_name = attr_name[:-9]
        elif attr_name.endswith("TokenizerFast"):
            lookup_dict = fast_tokenizers
            attr_name = attr_name[:-13]
        else:
            for regex, flags in model_regexes:
                match = regex.match(attr_name)
                if match is not None:
                    lookup_dict = flags
                    attr_name = match.groups()[0]
                    break

        if lookup_dict is None:
            continue

        while attr_name:
            if attr_name in known_prefixes:
                lookup_dict[attr_name] = True
                break
            # Try again after removing the last word in the name
            attr_name = "".join(camel_case_split(attr_name)[:-1])

    # Let's build that table!
    model_names = sorted(model_name_to_config)
    columns = ["Model", "Tokenizer slow", "Tokenizer fast", "PyTorch support", "TensorFlow support", "Flax Support"]
    # We'll need widths to properly display everything in the center (+2 is to leave one extra space on each side).
    widths = [len(c) + 2 for c in columns]
    widths[0] = max(len(name) for name in model_names) + 2

    # The thin border is identical after the first line and after every model row: compute it once.
    thin_border = "+" + "+".join(["-" * w for w in widths]) + "+\n"
    thick_border = "+" + "+".join(["=" * w for w in widths]) + "+\n"

    # Rst table per se
    pieces = [
        ".. rst-class:: center-aligned-table\n\n",
        thin_border,
        "|" + "|".join([_center_text(c, w) for c, w in zip(columns, widths)]) + "|\n",
        thick_border,
    ]

    check = {True: "✅", False: "❌"}
    for name in model_names:
        prefix = model_name_to_prefix[name]
        cells = [
            name,
            check[slow_tokenizers[prefix]],
            check[fast_tokenizers[prefix]],
            check[pt_models[prefix]],
            check[tf_models[prefix]],
            check[flax_models[prefix]],
        ]
        pieces.append("|" + "|".join([_center_text(cell, w) for cell, w in zip(cells, widths)]) + "|\n")
        pieces.append(thin_border)
    return "".join(pieces)
|
||||
|
||||
|
||||
def check_model_table(overwrite=False):
    """
    Compare the model table found in `index.rst` with the one generated from the auto modules.

    When both tables are equal nothing happens. Otherwise the file is rewritten with the generated
    table if `overwrite` is `True`, and a `ValueError` is raised if it is `False`.
    """
    index_file = os.path.join(PATH_TO_DOCS, "index.rst")
    current_table, start_index, end_index, lines = _find_text_in_file(
        filename=index_file,
        start_prompt=" This table is updated automatically from the auto module",
        end_prompt=".. toctree::",
    )
    new_table = get_model_table_from_auto_modules()

    # Nothing to do when the documented table is already up to date.
    if current_table == new_table:
        return

    if not overwrite:
        raise ValueError(
            "The model table in the `index.rst` has not been updated. Run `make fix-copies` to fix this."
        )

    # Keep everything around the table untouched and swap the table itself.
    updated_lines = lines[:start_index] + [new_table] + lines[end_index:]
    with open(index_file, "w", encoding="utf-8", newline="\n") as f:
        f.writelines(updated_lines)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Single optional switch: without it the checks only report, with it they rewrite the files.
    cli = argparse.ArgumentParser()
    cli.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    options = cli.parse_args()
    should_overwrite = options.fix_and_overwrite

    # Run the copy check first, then the model table check, both with the same overwrite setting.
    check_copies(should_overwrite)
    check_model_table(should_overwrite)
|
||||
|
||||
Reference in New Issue
Block a user