Document check copies (#25291)
* Document check copies better and add tests * Include header in check for copies * Manual fixes * Try autofix * Fixes * Clean tests * Finalize doc * Remove debug print * More fixes
This commit is contained in:
@@ -12,6 +12,29 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Utility that checks whether the copies defined in the library match the original or not. This includes:
|
||||
- All code commented with `# Copied from` comments,
|
||||
- The list of models in the main README.md matches the ones in the localized READMEs and in the index.md,
|
||||
- Files that are registered as full copies of one another in the `FULL_COPIES` constant of this script.
|
||||
|
||||
This also checks the list of models in the README is complete (has all models) and add a line to complete if there is
|
||||
a model missing.
|
||||
|
||||
Use from the root of the repo with:
|
||||
|
||||
```bash
|
||||
python utils/check_copies.py
|
||||
```
|
||||
|
||||
for a check that will error in case of inconsistencies (used by `make repo-consistency`) or
|
||||
|
||||
```bash
|
||||
python utils/check_copies.py --fix_and_overwrite
|
||||
```
|
||||
|
||||
for a check that will fix all inconsistencies automatically (used by `make fix-copies`).
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
@@ -103,7 +126,9 @@ transformers_module = direct_transformers_import(TRANSFORMERS_PATH)
|
||||
|
||||
|
||||
def _should_continue(line, indent):
|
||||
return line.startswith(indent) or len(line) <= 1 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None
|
||||
# Helper function. Returns `True` if `line` is empty, starts with the `indent` or is the end parenthesis of a
|
||||
# function definition
|
||||
return line.startswith(indent) or len(line.strip()) == 0 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None
|
||||
|
||||
|
||||
def find_code_in_transformers(object_name):
|
||||
@@ -140,7 +165,7 @@ def find_code_in_transformers(object_name):
|
||||
raise ValueError(f" {object_name} does not match any function or class in {module}.")
|
||||
|
||||
# We found the beginning of the class / func, now let's find the end (when the indent diminishes).
|
||||
start_index = line_index
|
||||
start_index = line_index - 1
|
||||
while line_index < len(lines) and _should_continue(lines[line_index], indent):
|
||||
line_index += 1
|
||||
# Clean up empty lines at the end (if any).
|
||||
@@ -179,6 +204,33 @@ def blackify(code):
|
||||
return result[len("class Bla:\n") :] if has_indent else result
|
||||
|
||||
|
||||
def check_codes_match(observed_code, theoretical_code):
|
||||
"""
|
||||
Checks if the code in `observed_code` and `theoretical_code` match with the exception of the class/function name.
|
||||
Returns the index of the first line where there is a difference (if any) and `None` if the codes match.
|
||||
"""
|
||||
observed_code_header = observed_code.split("\n")[0]
|
||||
theoretical_code_header = theoretical_code.split("\n")[0]
|
||||
|
||||
_re_class_match = re.compile(r"class\s+([^\(:]+)(?:\(|:)")
|
||||
_re_func_match = re.compile(r"def\s+([^\(]+)\(")
|
||||
for re_pattern in [_re_class_match, _re_func_match]:
|
||||
if re_pattern.match(observed_code_header) is not None:
|
||||
observed_obj_name = re_pattern.search(observed_code_header).groups()[0]
|
||||
theoretical_name = re_pattern.search(theoretical_code_header).groups()[0]
|
||||
theoretical_code_header = theoretical_code_header.replace(theoretical_name, observed_obj_name)
|
||||
|
||||
diff_index = 0
|
||||
if theoretical_code_header != observed_code_header:
|
||||
return 0
|
||||
|
||||
diff_index = 1
|
||||
for observed_line, theoretical_line in zip(observed_code.split("\n")[1:], theoretical_code.split("\n")[1:]):
|
||||
if observed_line != theoretical_line:
|
||||
return diff_index
|
||||
diff_index += 1
|
||||
|
||||
|
||||
def is_copy_consistent(filename, overwrite=False):
|
||||
"""
|
||||
Check if the code commented as a copy in `filename` matches the original.
|
||||
@@ -201,10 +253,11 @@ def is_copy_consistent(filename, overwrite=False):
|
||||
theoretical_code = find_code_in_transformers(object_name)
|
||||
theoretical_indent = get_indent(theoretical_code)
|
||||
|
||||
start_index = line_index + 1 if indent == theoretical_indent else line_index + 2
|
||||
indent = theoretical_indent
|
||||
line_index = start_index
|
||||
start_index = line_index + 1 if indent == theoretical_indent else line_index
|
||||
line_index = start_index + 1
|
||||
|
||||
subcode = "\n".join(theoretical_code.split("\n")[1:])
|
||||
indent = get_indent(subcode)
|
||||
# Loop to check the observed code, stop when indentation diminishes or if we see a End copy comment.
|
||||
should_continue = True
|
||||
while line_index < len(lines) and should_continue:
|
||||
@@ -212,6 +265,8 @@ def is_copy_consistent(filename, overwrite=False):
|
||||
if line_index >= len(lines):
|
||||
break
|
||||
line = lines[line_index]
|
||||
# There is a special pattern `# End copy` to stop early. It's not documented cause it shouldn't really be
|
||||
# used.
|
||||
should_continue = _should_continue(line, indent) and re.search(f"^{indent}# End copy", line) is None
|
||||
# Clean up empty lines at the end (if any).
|
||||
while len(lines[line_index - 1]) <= 1:
|
||||
@@ -233,19 +288,12 @@ def is_copy_consistent(filename, overwrite=False):
|
||||
theoretical_code = re.sub(obj1.lower(), obj2.lower(), theoretical_code)
|
||||
theoretical_code = re.sub(obj1.upper(), obj2.upper(), theoretical_code)
|
||||
|
||||
# Blackify after replacement. To be able to do that, we need the header (class or function definition)
|
||||
# from the previous line
|
||||
theoretical_code = blackify(lines[start_index - 1] + theoretical_code)
|
||||
theoretical_code = theoretical_code[len(lines[start_index - 1]) :]
|
||||
theoretical_code = blackify(theoretical_code)
|
||||
|
||||
# Test for a diff and act accordingly.
|
||||
if observed_code != theoretical_code:
|
||||
diff_index = start_index + 1
|
||||
for observed_line, theoretical_line in zip(observed_code.split("\n"), theoretical_code.split("\n")):
|
||||
if observed_line != theoretical_line:
|
||||
break
|
||||
diff_index += 1
|
||||
diffs.append([object_name, diff_index])
|
||||
diff_index = check_codes_match(observed_code, theoretical_code)
|
||||
if diff_index is not None:
|
||||
diffs.append([object_name, diff_index + start_index + 1])
|
||||
if overwrite:
|
||||
lines = lines[:start_index] + [theoretical_code] + lines[line_index:]
|
||||
line_index = start_index + 1
|
||||
@@ -259,6 +307,10 @@ def is_copy_consistent(filename, overwrite=False):
|
||||
|
||||
|
||||
def check_copies(overwrite: bool = False):
|
||||
"""
|
||||
Check every file is copy-consistent with the original and maybe `overwrite` content when it is not. Also check the
|
||||
model list in the main README and other READMEs/index.md are consistent.
|
||||
"""
|
||||
all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True)
|
||||
diffs = []
|
||||
for filename in all_files:
|
||||
@@ -275,6 +327,10 @@ def check_copies(overwrite: bool = False):
|
||||
|
||||
|
||||
def check_full_copies(overwrite: bool = False):
|
||||
"""
|
||||
Check the files that are full copies of others (as indicated in `FULL_COPIES`) are copy-consistent and maybe
|
||||
`overwrite` to fix issues.
|
||||
"""
|
||||
diffs = []
|
||||
for target, source in FULL_COPIES.items():
|
||||
with open(source, "r", encoding="utf-8") as f:
|
||||
@@ -299,7 +355,7 @@ def check_full_copies(overwrite: bool = False):
|
||||
|
||||
|
||||
def get_model_list(filename, start_prompt, end_prompt):
|
||||
"""Extracts the model list from the README."""
|
||||
"""Extracts the model list from a README, between `start_prompt` and `end_prompt`."""
|
||||
with open(os.path.join(REPO_PATH, filename), "r", encoding="utf-8", newline="\n") as f:
|
||||
lines = f.readlines()
|
||||
# Find the start of the list.
|
||||
@@ -327,7 +383,20 @@ def get_model_list(filename, start_prompt, end_prompt):
|
||||
|
||||
|
||||
def convert_to_localized_md(model_list, localized_model_list, format_str):
|
||||
"""Convert `model_list` to each localized README."""
|
||||
"""
|
||||
Compare the model list from the main README to the one in a localized README.
|
||||
|
||||
Args:
|
||||
model_list (`str`): The model list in the main README.
|
||||
localized_model_list (`str`): The model list in one of the localized README.
|
||||
format_str (`str`):
|
||||
The template for a model entry in the localized README (look at the `format_model_list` in the entries of
|
||||
`LOCALIZED_READMES` for examples).
|
||||
|
||||
Returns:
|
||||
`Tuple[bool, str]`: A tuple where the first value indicates if the READMEs match or not, and the second value
|
||||
is the correct localized README.
|
||||
"""
|
||||
|
||||
def _rep(match):
|
||||
title, model_link, paper_affiliations, paper_title_link, paper_authors, supplements = match.groups()
|
||||
@@ -341,7 +410,8 @@ def convert_to_localized_md(model_list, localized_model_list, format_str):
|
||||
)
|
||||
|
||||
# This regex captures metadata from an English model description, including model title, model link,
|
||||
# affiliations of the paper, title of the paper, authors of the paper, and supplemental data (see DistilBERT for example).
|
||||
# affiliations of the paper, title of the paper, authors of the paper, and supplemental data (see DistilBERT for
|
||||
# example).
|
||||
_re_capture_meta = re.compile(
|
||||
r"\*\*\[([^\]]*)\]\(([^\)]*)\)\*\* \(from ([^)]*)\)[^\[]*([^\)]*\)).*?by (.*?[A-Za-z\*]{2,}?)\. (.*)$"
|
||||
)
|
||||
@@ -389,6 +459,10 @@ def convert_to_localized_md(model_list, localized_model_list, format_str):
|
||||
|
||||
|
||||
def convert_readme_to_index(model_list):
|
||||
"""
|
||||
Converts the model list of the README to the index.md format.
|
||||
"""
|
||||
# We need to replce both link to the main doc and stable doc (the order of the next two instructions is important).
|
||||
model_list = model_list.replace("https://huggingface.co/docs/transformers/main/", "")
|
||||
return model_list.replace("https://huggingface.co/docs/transformers/", "")
|
||||
|
||||
@@ -420,7 +494,9 @@ def _find_text_in_file(filename, start_prompt, end_prompt):
|
||||
|
||||
|
||||
def check_model_list_copy(overwrite=False, max_per_line=119):
|
||||
"""Check the model lists in the README and index.rst are consistent and maybe `overwrite`."""
|
||||
"""
|
||||
Check the model lists in the README is consistent with the ones in the other READMES and also with `index.nmd`.
|
||||
"""
|
||||
# Fix potential doc links in the README
|
||||
with open(os.path.join(REPO_PATH, "README.md"), "r", encoding="utf-8", newline="\n") as f:
|
||||
readme = f.read()
|
||||
@@ -490,6 +566,7 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
|
||||
)
|
||||
|
||||
|
||||
# Map a model name with the name it has in the README for the check_readme check
|
||||
SPECIAL_MODEL_NAMES = {
|
||||
"Bert Generation": "BERT For Sequence Generation",
|
||||
"BigBird": "BigBird-RoBERTa",
|
||||
@@ -522,7 +599,7 @@ MODELS_NOT_IN_README = [
|
||||
"VisionTextDualEncoder",
|
||||
]
|
||||
|
||||
|
||||
# Template for new entries to add in the main README when we have missing models.
|
||||
README_TEMPLATE = (
|
||||
"1. **[{model_name}](https://huggingface.co/docs/main/transformers/model_doc/{model_type})** (from "
|
||||
"<FILL INSTITUTION>) released with the paper [<FILL PAPER TITLE>](<FILL ARKIV LINK>) by <FILL AUTHORS>."
|
||||
@@ -530,6 +607,10 @@ README_TEMPLATE = (
|
||||
|
||||
|
||||
def check_readme(overwrite=False):
|
||||
"""
|
||||
Check if the main README contains all the models in the library or not. If `overwrite`, will add an entry for the
|
||||
missing models using `README_TEMPLATE`.
|
||||
"""
|
||||
info = LOCALIZED_READMES["README.md"]
|
||||
models, start_index, end_index, lines = _find_text_in_file(
|
||||
os.path.join(REPO_PATH, "README.md"),
|
||||
|
||||
Reference in New Issue
Block a user