Check list of models in the main README and sort it (#17517)

* Script for README

* Fix copies

* Complete error message
This commit is contained in:
Sylvain Gugger
2022-06-02 08:10:08 -04:00
committed by GitHub
parent 588d8f1f26
commit 048dd73bba
8 changed files with 164 additions and 69 deletions

View File

@@ -15,6 +15,7 @@
import argparse
import glob
import importlib.util
import os
import re
@@ -72,6 +73,15 @@ LOCALIZED_READMES = {
}
# This is to make sure the transformers module imported is the one in the repo.
# The spec is built by hand from the `__init__.py` located under TRANSFORMERS_PATH, and the same directory is given
# as the submodule search location, so the package is loaded from that path rather than looked up by name.
spec = importlib.util.spec_from_file_location(
"transformers",
os.path.join(TRANSFORMERS_PATH, "__init__.py"),
submodule_search_locations=[TRANSFORMERS_PATH],
)
# NOTE(review): `Loader.load_module()` is documented as deprecated in favor of `exec_module()`; consider migrating
# (together with `importlib.util.module_from_spec`) once the behavior has been confirmed to be equivalent here.
transformers_module = spec.loader.load_module()
def _should_continue(line, indent):
return line.startswith(indent) or len(line) <= 1 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None
@@ -449,10 +459,88 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
)
# Maps a model name as it appears in the values of `MODEL_NAMES_MAPPING` (the name used in the repo) to the name
# under which the model is listed in the main README, for the models where the two differ.
SPECIAL_MODEL_NAMES = {
"Bert Generation": "BERT For Sequence Generation",
"BigBird": "BigBird-RoBERTa",
"Data2VecAudio": "Data2Vec",
"Data2VecText": "Data2Vec",
"Data2VecVision": "Data2Vec",
"Marian": "MarianMT",
"OpenAI GPT-2": "GPT-2",
"OpenAI GPT": "GPT",
"Perceiver": "Perceiver IO",
"ViT": "Vision Transformer (ViT)",
}
# Update this list with the models that shouldn't be in the README. This only concerns modular models or those that
# do not have an associated paper. Entries are compared against the names found in the values of
# `MODEL_NAMES_MAPPING` (not against the names used in the README).
MODELS_NOT_IN_README = [
"BertJapanese",
"Encoder decoder",
"FairSeq Machine-Translation",
"HerBERT",
"RetriBERT",
"Speech Encoder decoder",
"Speech2Text",
"Speech2Text2",
"Vision Encoder decoder",
"VisionTextDualEncoder",
"Wav2Vec2-Conformer",
]
# Placeholder entry inserted in the README model list for a missing model. It is formatted with two fields:
# `model_name` (the displayed name) and `model_type` (the last component of the documentation URL). The `<FILL ...>`
# parts are left verbatim in the generated line and are meant to be completed by hand.
# NOTE(review): "ARKIV" is presumably meant to read "ARXIV" -- confirm nothing relies on the current spelling before
# changing it.
README_TEMPLATE = (
"1. **[{model_name}](https://huggingface.co/docs/transformers/model_doc/{model_type})** (from <FILL INSTITUTION>) "
"released with the paper [<FILL PAPER TITLE>](<FILL ARKIV LINK>) by <FILL AUTHORS>."
)
def _get_readme_model_name(line):
    """
    Extract the model name from one entry of the README model list.

    Args:
        line (`str`): One line of the model list, expected to contain `**[<model name>]`.

    Returns:
        `str`: The text found between `**[` and the following `]`.

    Raises:
        ValueError: If `line` contains no `**[...` marker (for instance a blank or malformed line).
    """
    match = re.search(r"\*\*\[([^\]]*)", line)
    if match is None:
        raise ValueError(
            f"Could not find a model name in the following line of the main README model list: {line!r}"
        )
    return match.group(1)


def check_readme(overwrite=False):
    """
    Check that the model list of the main README contains every expected model and is sorted by model name.

    The expected models are the values of `MODEL_NAMES_MAPPING`, minus the ones listed in `MODELS_NOT_IN_README`. A
    model counts as present when its README name (its `SPECIAL_MODEL_NAMES` entry if there is one, its own name
    otherwise) appears in the list. Sorting is case-insensitive and uses the name found between `**[` and `]`.

    Args:
        overwrite (`bool`, *optional*, defaults to `False`):
            Whether to rewrite the README (adding a placeholder entry for each missing model and sorting the list)
            instead of raising an error.

    Raises:
        ValueError: If `overwrite=False` and some models are missing or the list is not sorted, or if a line of the
            model list does not contain a model name.
    """
    info = LOCALIZED_READMES["README.md"]
    readme_path = os.path.join(REPO_PATH, "README.md")
    models, start_index, end_index, lines = _find_text_in_file(
        readme_path,
        info["start_prompt"],
        info["end_prompt"],
    )
    readme_entries = models.strip().split("\n")
    models_in_readme = {_get_readme_model_name(entry) for entry in readme_entries}

    model_names_mapping = transformers_module.models.auto.configuration_auto.MODEL_NAMES_MAPPING
    # A model is absent when it is not a known exception and its README name is not in the list.
    absents = [
        (key, name)
        for key, name in model_names_mapping.items()
        if name not in MODELS_NOT_IN_README and SPECIAL_MODEL_NAMES.get(name, name) not in models_in_readme
    ]
    if absents and not overwrite:
        print(absents)
        raise ValueError(
            "The main README doesn't contain all models, run `make fix-copies` to fill it with the missing model(s)"
            " then complete the generated entries.\nIf the model is not supposed to be in the main README, add it to"
            " the list `MODELS_NOT_IN_README` in utils/check_copies.py.\nIf it has a different name in the repo than"
            " in the README, map the correspondence in `SPECIAL_MODEL_NAMES` in utils/check_copies.py."
        )

    # Generated entries must use the README name: the presence test above looks for that name, so writing the raw
    # repo name would leave the model reported as missing (and added again) on the next run.
    new_models = [
        README_TEMPLATE.format(model_name=SPECIAL_MODEL_NAMES.get(name, name), model_type=key)
        for key, name in absents
    ]

    all_models = sorted(readme_entries + new_models, key=lambda entry: _get_readme_model_name(entry).lower())
    all_models = "\n".join(all_models) + "\n"

    if all_models != models:
        if overwrite:
            print("Fixing the main README.")
            with open(readme_path, "w", encoding="utf-8", newline="\n") as f:
                # Replace the lines in [start_index, end_index) with the regenerated model list.
                f.writelines(lines[:start_index] + [all_models] + lines[end_index:])
        else:
            raise ValueError("The main README model list is not properly sorted. Run `make fix-copies` to fix this.")
if __name__ == "__main__":
    # Command-line entry point: a single flag selects between "check and raise" and "fix in place".
    cli = argparse.ArgumentParser()
    cli.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    fix_and_overwrite = cli.parse_args().fix_and_overwrite

    # The checks run in this order, each receiving the value of the flag.
    for run_check in (check_readme, check_copies, check_full_copies):
        run_check(fix_and_overwrite)