Add support for auto_docstring with model outputs (#38242)
* experiment auto_docstring model outputs * Fix PatchTSMixer * Add check model output docstring to check_auto_docstring and fix all model outputs docstring * add reordering of docstring in check_docstrings * add check for redundant docstring in check_docstrings, remove redundant docstrings * refactor check_auto_docstring * make style * fix copies * remove commented code * change List-> list Tuple-> tuple in docstrings * fix modular * make style * Fix modular vipllava --------- Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
This commit is contained in:
@@ -41,6 +41,7 @@ import inspect
|
||||
import operator as op
|
||||
import os
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
@@ -48,7 +49,14 @@ from check_repo import ignore_undocumented
|
||||
from git import Repo
|
||||
|
||||
from transformers.utils import direct_transformers_import
|
||||
from transformers.utils.args_doc import ImageProcessorArgs, ModelArgs, parse_docstring, set_min_indent, source_args_doc
|
||||
from transformers.utils.args_doc import (
|
||||
ImageProcessorArgs,
|
||||
ModelArgs,
|
||||
ModelOutputArgs,
|
||||
get_args_doc_from_source,
|
||||
parse_docstring,
|
||||
set_min_indent,
|
||||
)
|
||||
|
||||
|
||||
PATH_TO_REPO = Path(__file__).parent.parent.resolve()
|
||||
@@ -64,7 +72,8 @@ _re_args = re.compile(r"^\s*(Args?|Arguments?|Attributes?|Params?|Parameters?):\
|
||||
_re_parse_arg = re.compile(r"^(\s*)(\S+)\s+\((.+)\)(?:\:|$)")
|
||||
# Re pattern that parses the end of a description of an arg (catches the default in *optional*, defaults to xxx).
|
||||
_re_parse_description = re.compile(r"\*optional\*, defaults to (.*)$")
|
||||
|
||||
# Args that are always overridden in the docstring, for clarity we don't want to remove them from the docstring
|
||||
ALWAYS_OVERRIDE = ["labels"]
|
||||
|
||||
# This is a temporary list of objects to ignore while we progressively fix them. Do not add anything here, fix the
|
||||
# docstrings instead. If formatting should be ignored for the docstring, you can put a comment # no-format on the
|
||||
@@ -979,6 +988,19 @@ def _find_sig_line(lines, line_end):
|
||||
return sig_line_end
|
||||
|
||||
|
||||
def _find_docstring_end_line(lines, docstring_start_line):
|
||||
if '"""' not in lines[docstring_start_line]:
|
||||
return None
|
||||
docstring_end = docstring_start_line
|
||||
if docstring_start_line is not None:
|
||||
docstring_end = docstring_start_line
|
||||
if not lines[docstring_start_line].count('"""') >= 2:
|
||||
docstring_end += 1
|
||||
while '"""' not in lines[docstring_end]:
|
||||
docstring_end += 1
|
||||
return docstring_end
|
||||
|
||||
|
||||
def find_matching_model_files(check_all: bool = False):
|
||||
"""
|
||||
Find all model files in the transformers repo that should be checked for @auto_docstring,
|
||||
@@ -1074,13 +1096,34 @@ def get_auto_docstring_candidate_lines(lines):
|
||||
return line_starts_candidates, line_ends_candidates
|
||||
|
||||
|
||||
def get_args_in_signature(lines, signature_content):
    """
    Extract the parameter names from the source lines of a function signature.

    Args:
        lines (`list[str]`):
            The full file content split into lines. Not used by this function.
        signature_content (`list[str]`):
            The source lines that make up the signature.

    Returns:
        `list[str]`: Parameter names in order of appearance, with the first `self` removed. A name is only
        captured when it is followed by `:`, `=` or `,`, so a final parameter with no annotation, default or
        trailing comma is not returned.
    """
    # Strip trailing `#` comments from each line and glue the signature into a single string.
    code_only = "".join(raw_line.split("#")[0] for raw_line in signature_content)
    # Keep only the text before the last ")" and drop every other ")" in it.
    before_last_paren = code_only.rpartition(")")[0].replace(")", "")
    found_names = re.findall(r"[,(]\s*(\w+)\s*(?=:|=|,|\))", before_last_paren)
    # `list.remove` drops the first occurrence only.
    if "self" in found_names:
        found_names.remove("self")
    return found_names
|
||||
|
||||
|
||||
def get_args_in_dataclass(lines, dataclass_content):
    """
    Extract the field names declared in the body of a dataclass-style class.

    Args:
        lines (`list[str]`):
            The full file content split into lines. Not used by this function.
        dataclass_content (`list[str]`):
            The source lines of the class body.

    Returns:
        `list[str]`: Names found at exactly one indentation level (four spaces) that are followed by `:`, `=`
        or the end of the line, in order of appearance, with the first `self` removed.
    """
    # Strip trailing `#` comments so commented-out fields are not picked up.
    code_lines = [line.split("#")[0] for line in dataclass_content]
    # The indent is written as ` {4}` so it stays exactly four spaces even if whitespace gets collapsed.
    # A single leading space would never match a normally indented field, and anything deeper than four
    # spaces (e.g. statements inside a method) must not be treated as a class attribute.
    args_in_dataclass = re.findall(r"^ {4}(\w+)(?:\s*:|\s*=|\s*$)", "\n".join(code_lines), re.MULTILINE)
    if "self" in args_in_dataclass:
        args_in_dataclass.remove("self")
    return args_in_dataclass
|
||||
|
||||
|
||||
def generate_new_docstring_for_signature(
|
||||
lines,
|
||||
sig_start_line,
|
||||
args_in_signature,
|
||||
sig_end_line,
|
||||
docstring_line,
|
||||
docstring_start_line,
|
||||
arg_indent=" ",
|
||||
output_docstring_indent=8,
|
||||
custom_args_dict={},
|
||||
source_args_doc=[ModelArgs, ImageProcessorArgs],
|
||||
):
|
||||
"""
|
||||
Generalized docstring generator for a function or class signature.
|
||||
@@ -1095,33 +1138,47 @@ def generate_new_docstring_for_signature(
|
||||
"""
|
||||
# Extract and clean signature
|
||||
missing_docstring_args = []
|
||||
docstring_args_ro_remove = []
|
||||
fill_docstring_args = []
|
||||
|
||||
signature_content = lines[sig_start_line:sig_end_line]
|
||||
signature_content = [line.split("#")[0] for line in signature_content]
|
||||
signature_content = "".join(signature_content)
|
||||
signature_content = "".join(signature_content.split(")")[:-1])
|
||||
args_in_signature = re.findall(r"[,(]\s*(\w+)\s*(?=:|=|,|\))", signature_content)
|
||||
if "self" in args_in_signature:
|
||||
args_in_signature.remove("self")
|
||||
# Parse docstring if present
|
||||
args_docstring_dict = {}
|
||||
remaining_docstring = ""
|
||||
docstring_end = sig_end_line - 1
|
||||
if docstring_line is not None:
|
||||
docstring_end = docstring_line
|
||||
if not lines[docstring_line].count('"""') >= 2:
|
||||
docstring_end += 1
|
||||
while '"""' not in lines[docstring_end]:
|
||||
docstring_end += 1
|
||||
docstring_content = lines[docstring_line : docstring_end + 1]
|
||||
if docstring_start_line is not None:
|
||||
docstring_end_line = _find_docstring_end_line(lines, docstring_start_line)
|
||||
docstring_content = lines[docstring_start_line : docstring_end_line + 1]
|
||||
parsed_docstring, remaining_docstring = parse_docstring("\n".join(docstring_content))
|
||||
args_docstring_dict.update(parsed_docstring)
|
||||
else:
|
||||
docstring_end_line = None
|
||||
|
||||
# Remove args that are the same as the ones in the source args doc
|
||||
for arg in args_docstring_dict:
|
||||
if arg in get_args_doc_from_source(source_args_doc) and arg not in ALWAYS_OVERRIDE:
|
||||
source_arg_doc = get_args_doc_from_source(source_args_doc)[arg]
|
||||
if source_arg_doc["description"].strip("\n ") == args_docstring_dict[arg]["description"].strip("\n "):
|
||||
if source_arg_doc.get("shape") is not None and args_docstring_dict[arg].get("shape") is not None:
|
||||
if source_arg_doc.get("shape").strip("\n ") == args_docstring_dict[arg].get("shape").strip("\n "):
|
||||
docstring_args_ro_remove.append(arg)
|
||||
elif (
|
||||
source_arg_doc.get("additional_info") is not None
|
||||
and args_docstring_dict[arg].get("additional_info") is not None
|
||||
):
|
||||
if source_arg_doc.get("additional_info").strip("\n ") == args_docstring_dict[arg].get(
|
||||
"additional_info"
|
||||
).strip("\n "):
|
||||
docstring_args_ro_remove.append(arg)
|
||||
else:
|
||||
docstring_args_ro_remove.append(arg)
|
||||
args_docstring_dict = {
|
||||
arg: args_docstring_dict[arg] for arg in args_docstring_dict if arg not in docstring_args_ro_remove
|
||||
}
|
||||
|
||||
# Fill missing args
|
||||
for arg in args_in_signature:
|
||||
if (
|
||||
arg not in args_docstring_dict
|
||||
and arg not in source_args_doc([ModelArgs, ImageProcessorArgs])
|
||||
and arg not in get_args_doc_from_source(source_args_doc)
|
||||
and arg not in custom_args_dict
|
||||
):
|
||||
missing_docstring_args.append(arg)
|
||||
@@ -1129,22 +1186,33 @@ def generate_new_docstring_for_signature(
|
||||
"type": "<fill_type>",
|
||||
"optional": False,
|
||||
"shape": None,
|
||||
"description": f"\n{arg_indent} <fill_docstring>",
|
||||
"description": "\n <fill_docstring>",
|
||||
"default": None,
|
||||
"additional_info": None,
|
||||
}
|
||||
|
||||
# Handle docstring of inherited args (for dataclasses)
|
||||
ordered_args_docstring_dict = OrderedDict(
|
||||
(arg, args_docstring_dict[arg]) for arg in args_docstring_dict if arg not in args_in_signature
|
||||
)
|
||||
# Add args in the order of the signature
|
||||
ordered_args_docstring_dict.update(
|
||||
(arg, args_docstring_dict[arg]) for arg in args_in_signature if arg in args_docstring_dict
|
||||
)
|
||||
# Build new docstring
|
||||
new_docstring = ""
|
||||
if len(args_docstring_dict) > 0 or remaining_docstring:
|
||||
if len(ordered_args_docstring_dict) > 0 or remaining_docstring:
|
||||
new_docstring += 'r"""\n'
|
||||
for arg in args_docstring_dict:
|
||||
additional_info = args_docstring_dict[arg]["additional_info"] or ""
|
||||
custom_arg_description = args_docstring_dict[arg]["description"]
|
||||
for arg in ordered_args_docstring_dict:
|
||||
additional_info = ordered_args_docstring_dict[arg]["additional_info"] or ""
|
||||
custom_arg_description = ordered_args_docstring_dict[arg]["description"]
|
||||
if "<fill_docstring>" in custom_arg_description and arg not in missing_docstring_args:
|
||||
fill_docstring_args.append(arg)
|
||||
if custom_arg_description.endswith('"""'):
|
||||
custom_arg_description = "\n".join(custom_arg_description.split("\n")[:-1])
|
||||
new_docstring += f"{arg} ({args_docstring_dict[arg]['type']}{additional_info}):{custom_arg_description}\n"
|
||||
new_docstring += (
|
||||
f"{arg} ({ordered_args_docstring_dict[arg]['type']}{additional_info}):{custom_arg_description}\n"
|
||||
)
|
||||
close_docstring = True
|
||||
if remaining_docstring:
|
||||
if remaining_docstring.endswith('"""'):
|
||||
@@ -1153,21 +1221,31 @@ def generate_new_docstring_for_signature(
|
||||
new_docstring += f"{set_min_indent(remaining_docstring, 0)}{end_docstring}"
|
||||
if close_docstring:
|
||||
new_docstring += '"""'
|
||||
new_docstring = set_min_indent(new_docstring, 8)
|
||||
return new_docstring, sig_end_line, docstring_end, missing_docstring_args, fill_docstring_args
|
||||
new_docstring = set_min_indent(new_docstring, output_docstring_indent)
|
||||
|
||||
return (
|
||||
new_docstring,
|
||||
sig_end_line,
|
||||
docstring_end_line if docstring_end_line is not None else sig_end_line - 1,
|
||||
missing_docstring_args,
|
||||
fill_docstring_args,
|
||||
docstring_args_ro_remove,
|
||||
)
|
||||
|
||||
|
||||
def generate_new_docstring_for_function(lines, current_line_end, custom_args_dict):
|
||||
"""
|
||||
Wrapper for function docstring generation using the generalized helper.
|
||||
"""
|
||||
sig_line_end = _find_sig_line(lines, current_line_end)
|
||||
docstring_line = sig_line_end if '"""' in lines[sig_line_end] else None
|
||||
sig_end_line = _find_sig_line(lines, current_line_end)
|
||||
signature_content = lines[current_line_end:sig_end_line]
|
||||
args_in_signature = get_args_in_signature(lines, signature_content)
|
||||
docstring_start_line = sig_end_line if '"""' in lines[sig_end_line] else None
|
||||
return generate_new_docstring_for_signature(
|
||||
lines,
|
||||
current_line_end,
|
||||
sig_line_end,
|
||||
docstring_line,
|
||||
args_in_signature,
|
||||
sig_end_line,
|
||||
docstring_start_line,
|
||||
arg_indent=" ",
|
||||
custom_args_dict=custom_args_dict,
|
||||
)
|
||||
@@ -1178,35 +1256,50 @@ def generate_new_docstring_for_class(lines, current_line_end, custom_args_dict):
|
||||
Wrapper for class docstring generation (via __init__) using the generalized helper.
|
||||
Returns the new docstring and relevant signature/docstring indices.
|
||||
"""
|
||||
init_method_line = current_line_end
|
||||
sig_start_line = current_line_end
|
||||
found_init_method = False
|
||||
while init_method_line < len(lines) - 1 and not found_init_method:
|
||||
init_method_line += 1
|
||||
if " def __init__" in lines[init_method_line]:
|
||||
found_model_output = False
|
||||
while sig_start_line < len(lines) - 1 and not found_init_method:
|
||||
sig_start_line += 1
|
||||
if " def __init__" in lines[sig_start_line]:
|
||||
found_init_method = True
|
||||
elif lines[init_method_line].startswith("class "):
|
||||
elif lines[sig_start_line].startswith("class ") or lines[sig_start_line].startswith("def "):
|
||||
break
|
||||
if not found_init_method:
|
||||
return "", None, None, None, [], []
|
||||
init_method_sig_line_end = _find_sig_line(lines, init_method_line)
|
||||
docstring_line = init_method_sig_line_end if '"""' in lines[init_method_sig_line_end] else None
|
||||
new_docstring, _, init_method_docstring_end, missing_docstring_args, fill_docstring_args = (
|
||||
generate_new_docstring_for_signature(
|
||||
lines,
|
||||
init_method_line,
|
||||
init_method_sig_line_end,
|
||||
docstring_line,
|
||||
arg_indent="",
|
||||
custom_args_dict=custom_args_dict,
|
||||
)
|
||||
)
|
||||
return (
|
||||
new_docstring,
|
||||
init_method_line,
|
||||
init_method_sig_line_end,
|
||||
init_method_docstring_end,
|
||||
missing_docstring_args,
|
||||
fill_docstring_args,
|
||||
if "ModelOutput" in lines[current_line_end]:
|
||||
found_model_output = True
|
||||
sig_start_line = current_line_end
|
||||
else:
|
||||
return "", None, None, [], [], []
|
||||
|
||||
if found_init_method:
|
||||
sig_end_line = _find_sig_line(lines, sig_start_line)
|
||||
signature_content = lines[sig_start_line:sig_end_line]
|
||||
args_in_signature = get_args_in_signature(lines, signature_content)
|
||||
else:
|
||||
# we have a ModelOutput class, the class attributes are the args
|
||||
sig_end_line = sig_start_line + 1
|
||||
docstring_end = _find_docstring_end_line(lines, sig_end_line)
|
||||
model_output_class_start = docstring_end + 1 if docstring_end is not None else sig_end_line - 1
|
||||
model_output_class_end = model_output_class_start
|
||||
while model_output_class_end < len(lines) and (
|
||||
lines[model_output_class_end].startswith(" ") or lines[model_output_class_end] == ""
|
||||
):
|
||||
model_output_class_end += 1
|
||||
dataclass_content = lines[model_output_class_start : model_output_class_end - 1]
|
||||
args_in_signature = get_args_in_dataclass(lines, dataclass_content)
|
||||
|
||||
docstring_start_line = sig_end_line if '"""' in lines[sig_end_line] else None
|
||||
|
||||
return generate_new_docstring_for_signature(
|
||||
lines,
|
||||
args_in_signature,
|
||||
sig_end_line,
|
||||
docstring_start_line,
|
||||
arg_indent="",
|
||||
custom_args_dict=custom_args_dict,
|
||||
output_docstring_indent=4 if found_model_output else 8,
|
||||
source_args_doc=[ModelArgs, ImageProcessorArgs] if not found_model_output else [ModelOutputArgs],
|
||||
)
|
||||
|
||||
|
||||
@@ -1255,8 +1348,9 @@ def update_file_with_new_docstrings(
|
||||
current_line_end = line_ends_candidates[0]
|
||||
index = 1
|
||||
missing_docstring_args_warnings = []
|
||||
|
||||
fill_docstring_args_warnings = []
|
||||
docstring_args_ro_remove_warnings = []
|
||||
|
||||
while index <= len(line_starts_candidates):
|
||||
custom_args_dict = {}
|
||||
auto_docstring_signature_content = "".join(lines[current_line_start:current_line_end])
|
||||
@@ -1267,23 +1361,28 @@ def update_file_with_new_docstrings(
|
||||
if custom_args_var_content:
|
||||
custom_args_dict, _ = parse_docstring(custom_args_var_content)
|
||||
new_docstring = ""
|
||||
found_init_method = False
|
||||
modify_class_docstring = False
|
||||
# Function
|
||||
if " def" in lines[current_line_end]:
|
||||
new_docstring, sig_line_end, docstring_end, missing_docstring_args, fill_docstring_args = (
|
||||
generate_new_docstring_for_function(lines, current_line_end, custom_args_dict)
|
||||
)
|
||||
(
|
||||
new_docstring,
|
||||
sig_line_end,
|
||||
docstring_end,
|
||||
missing_docstring_args,
|
||||
fill_docstring_args,
|
||||
docstring_args_ro_remove,
|
||||
) = generate_new_docstring_for_function(lines, current_line_end, custom_args_dict)
|
||||
# Class
|
||||
elif "class " in lines[current_line_end]:
|
||||
(
|
||||
new_docstring,
|
||||
init_method_line,
|
||||
init_method_sig_line_end,
|
||||
init_method_docstring_end,
|
||||
class_sig_line_end,
|
||||
class_docstring_end_line,
|
||||
missing_docstring_args,
|
||||
fill_docstring_args,
|
||||
docstring_args_ro_remove,
|
||||
) = generate_new_docstring_for_class(lines, current_line_end, custom_args_dict)
|
||||
found_init_method = init_method_line is not None
|
||||
modify_class_docstring = class_sig_line_end is not None
|
||||
# Add warnings if needed
|
||||
if missing_docstring_args:
|
||||
for arg in missing_docstring_args:
|
||||
@@ -1291,7 +1390,9 @@ def update_file_with_new_docstrings(
|
||||
if fill_docstring_args:
|
||||
for arg in fill_docstring_args:
|
||||
fill_docstring_args_warnings.append(f" - {arg} line {current_line_end}")
|
||||
|
||||
if docstring_args_ro_remove:
|
||||
for arg in docstring_args_ro_remove:
|
||||
docstring_args_ro_remove_warnings.append(f" - {arg} line {current_line_end}")
|
||||
# Write new lines
|
||||
if index >= len(line_ends_candidates) or line_ends_candidates[index] > current_line_end:
|
||||
if " def" in lines[current_line_end]:
|
||||
@@ -1302,14 +1403,14 @@ def update_file_with_new_docstrings(
|
||||
content_base_file_new_lines += lines[docstring_end + 1 : line_ends_candidates[index]]
|
||||
else:
|
||||
content_base_file_new_lines += lines[docstring_end + 1 :]
|
||||
elif found_init_method:
|
||||
content_base_file_new_lines += lines[current_line_end:init_method_sig_line_end]
|
||||
elif modify_class_docstring:
|
||||
content_base_file_new_lines += lines[current_line_end:class_sig_line_end]
|
||||
if new_docstring != "":
|
||||
content_base_file_new_lines += new_docstring.split("\n")
|
||||
if index < len(line_ends_candidates):
|
||||
content_base_file_new_lines += lines[init_method_docstring_end + 1 : line_ends_candidates[index]]
|
||||
content_base_file_new_lines += lines[class_docstring_end_line + 1 : line_ends_candidates[index]]
|
||||
else:
|
||||
content_base_file_new_lines += lines[init_method_docstring_end + 1 :]
|
||||
content_base_file_new_lines += lines[class_docstring_end_line + 1 :]
|
||||
elif index < len(line_ends_candidates):
|
||||
content_base_file_new_lines += lines[current_line_end : line_ends_candidates[index]]
|
||||
else:
|
||||
@@ -1323,9 +1424,19 @@ def update_file_with_new_docstrings(
|
||||
with open(candidate_file, "w", encoding="utf-8") as f:
|
||||
f.write(content_base_file_new)
|
||||
|
||||
return missing_docstring_args_warnings, fill_docstring_args_warnings
|
||||
return (
|
||||
missing_docstring_args_warnings,
|
||||
fill_docstring_args_warnings,
|
||||
docstring_args_ro_remove_warnings,
|
||||
)
|
||||
|
||||
|
||||
# TODO (Yoni): The functions in check_auto_docstrings rely on direct code parsing, which is prone to
|
||||
# failure on edge cases and not robust to code changes. While this approach is significantly faster
|
||||
# than using inspect (like in check_docstrings) and allows parsing any object including non-public
|
||||
# ones, it may need to be refactored in the future to use a more robust parsing method. Note that
|
||||
# we still need auto_docstring for some non-public objects since their docstrings are included in the
|
||||
# docs of public objects (e.g. ModelOutput classes).
|
||||
def check_auto_docstrings(overwrite: bool = False, check_all: bool = False):
|
||||
"""
|
||||
Check docstrings of all public objects that are decorated with `@auto_docstring`.
|
||||
@@ -1343,8 +1454,10 @@ def check_auto_docstrings(overwrite: bool = False, check_all: bool = False):
|
||||
with open(candidate_file, "r", encoding="utf-8") as f:
|
||||
lines = f.read().split("\n")
|
||||
line_starts_candidates, line_ends_candidates = get_auto_docstring_candidate_lines(lines)
|
||||
missing_docstring_args_warnings, fill_docstring_args_warnings = update_file_with_new_docstrings(
|
||||
candidate_file, lines, line_starts_candidates, line_ends_candidates, overwrite=overwrite
|
||||
missing_docstring_args_warnings, fill_docstring_args_warnings, docstring_args_ro_remove_warnings = (
|
||||
update_file_with_new_docstrings(
|
||||
candidate_file, lines, line_starts_candidates, line_ends_candidates, overwrite=overwrite
|
||||
)
|
||||
)
|
||||
if missing_docstring_args_warnings:
|
||||
if not overwrite:
|
||||
@@ -1354,6 +1467,14 @@ def check_auto_docstrings(overwrite: bool = False, check_all: bool = False):
|
||||
print(f"🚨 Missing docstring for the following arguments in {candidate_file}:")
|
||||
for warning in missing_docstring_args_warnings:
|
||||
print(warning)
|
||||
if docstring_args_ro_remove_warnings:
|
||||
if not overwrite:
|
||||
print(
|
||||
"Some docstrings are redundant with the ones in `args_doc.py` and will be removed. Run `make fix-copies` or `python utils/check_docstrings.py --fix_and_overwrite` to remove the redundant docstrings."
|
||||
)
|
||||
print(f"🚨 Redundant docstring for the following arguments in {candidate_file}:")
|
||||
for warning in docstring_args_ro_remove_warnings:
|
||||
print(warning)
|
||||
if fill_docstring_args_warnings:
|
||||
print(f"🚨 Docstring needs to be filled for the following arguments in {candidate_file}:")
|
||||
for warning in fill_docstring_args_warnings:
|
||||
@@ -1418,9 +1539,6 @@ def check_docstrings(overwrite: bool = False, check_all: bool = False):
|
||||
hard_failures.append(name)
|
||||
continue
|
||||
if old_doc != new_doc:
|
||||
print("name", name)
|
||||
print("old_doc", old_doc)
|
||||
print("new_doc", new_doc)
|
||||
if overwrite:
|
||||
fix_docstring(obj, old_doc, new_doc)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user