[AutoDocstring] Based on inspect parsing of the signature (#33771)

* delete common docstring

* nit

* updates

* push

* fixup

* move stuff around fixup

* no need for dataclass

* damn nice modular

* add auto class docstring

* style

* modular update

* import autodocstring

* fixup

* maybe add original doc!

* more cleanup

* remove class doc as well

* update

* nits

* more cleanup

* fix

* wups

* small check

* updatez

* some fixes

* fix doc

* update

* nits

* try?

* nit

* some updates

* a little bit better

* where ever we did not have help we are not really adding it!

* revert llama config

* small fixes and small tests

* test

* fixup

* more fix-copies

* updates

* updates

* fix doc building

* style

* small fixes

* nits

* fix-copies

* fix merge issues faster

* fix merge conf

* nits jamba

* ?

* working autodoc for model class and forward except returns and example

* support return section and unpack kwargs description

* nits and cleanup

* fix-copies

* fix-copies

* nits

* Add support for llava-like models

* fixup

* add class args subset support

* add examples inferred from automodel/pipelines

* update ruff

* autodocstring for Aria, Albert + fixups

* Fix empty return blocks

* fix copies

* fix copies

* add autodoc for all fast image processors + align, altclip

* fix copies

* add auto_doc for audio_spectrogram, auto_former, bark, bamba

* Drastically improve speed + add bart beit bert

* add autodoc to all bert-like models

* Fix broken doc

* fix copies

* fix auto_docstring after merge

* add autodoc to models

* add models

* add models

* add models and improve support for optional, and custom shape in args docstring

* update fast image processors

* refactor auto_method_docstring in args_doc

* add models and fix docstring parsing

* add models

* add models

* remove debugging

* add models

* add fix_auto_docstrings and improve args_docs

* add support for additional_info in args docstring

* refactor (almost) all models

* fix check docstring

* fix -copies

* fill in all missing docstrings

* fix copies

* fix qwen3 moe docstring

* add documentation

* add back labels

* update docs and fix can_return_tuple in modular files

* fix LongformerForMaskedLM docstring

* add auto_docstring to _toctree

* remove auto_docstring tests temporarily

* fix copyrights new files

* fix can_return_tuple granite hybrid

* fix fast beit

* Fix empty config doc

* add support for COMMON_CUSTOM_ARGS in check_docstrings and add missing models

* fix code block not closed flava

* fix can_return_tuple sam hq

* Fix Flaubert dataclass

---------

Co-authored-by: yonigozlan <yoni.gozlan@huggingface.co>
Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
This commit is contained in:
Arthur
2025-05-08 23:46:07 +02:00
committed by GitHub
parent d231f5a7d4
commit 5f5ccfdc54
405 changed files with 18189 additions and 46715 deletions

View File

@@ -36,8 +36,10 @@ like argument descriptions).
import argparse
import ast
import enum
import glob
import inspect
import operator as op
import os
import re
from pathlib import Path
from typing import Any, Optional, Tuple, Union
@@ -46,6 +48,7 @@ from check_repo import ignore_undocumented
from git import Repo
from transformers.utils import direct_transformers_import
from transformers.utils.args_doc import ImageProcessorArgs, ModelArgs, parse_docstring, set_min_indent, source_args_doc
PATH_TO_REPO = Path(__file__).parent.parent.resolve()
@@ -959,6 +962,404 @@ def fix_docstring(obj: Any, old_doc_args: str, new_doc_args: str):
f.write("\n".join(lines))
def _find_sig_line(lines, line_end):
parenthesis_count = 0
sig_line_end = line_end
found_sig = False
while not found_sig:
for char in lines[sig_line_end]:
if char == "(":
parenthesis_count += 1
elif char == ")":
parenthesis_count -= 1
if parenthesis_count == 0:
found_sig = True
break
sig_line_end += 1
return sig_line_end
def find_matching_model_files(check_all: bool = False):
    """
    Collect the model files that have to be inspected for `@auto_docstring` usage.

    Candidates are the paths matching the `modeling_**` and `image_processing_*_fast.py` glob patterns under the
    `models` folder of `PATH_TO_TRANSFORMERS`, minus the files whose name contains `modeling_tf_` or
    `modeling_flax_`. Unless `check_all` is set, only the candidates that appear in the git diff (index vs
    working tree, and index vs `main`) are kept.

    Args:
        check_all: Whether to return every candidate instead of only the ones present in the diff.

    Returns:
        List of file paths, or `None` when `check_all` is `False` and the diff contains no file under
        `src/transformers`.
    """
    changed_files = None
    if not check_all:
        repo = Repo(PATH_TO_REPO)
        changed_files = set()
        # Diff from index to unstaged files
        changed_files.update(
            os.path.join(PATH_TO_REPO, diff.a_path)
            for diff in repo.index.diff(None)
            if diff.a_path.startswith("src/transformers")
        )
        # Diff from index to `main`
        changed_files.update(
            os.path.join(PATH_TO_REPO, diff.a_path)
            for diff in repo.index.diff(repo.refs.main.commit)
            if diff.a_path.startswith("src/transformers")
        )
        # Nothing relevant was modified: tell the caller there is nothing to check
        if not changed_files:
            return None

    candidates = glob.glob(os.path.join(PATH_TO_TRANSFORMERS, "models/**/modeling_**"))
    candidates += glob.glob(os.path.join(PATH_TO_TRANSFORMERS, "models/**/image_processing_*_fast.py"))

    excluded_markers = ("modeling_tf_", "modeling_flax_")
    matching_files = [
        path
        for path in candidates
        if os.path.isfile(path) and not any(marker in os.path.basename(path) for marker in excluded_markers)
    ]
    if not check_all:
        # Keep only the candidates that are part of the diff
        matching_files = sorted(path for path in matching_files if path in changed_files)

    print(" Checking auto_docstrings in the following files:" + "\n - " + "\n - ".join(matching_files))

    return matching_files
def find_files_with_auto_docstring(matching_files, decorator="@auto_docstring"):
    """
    From a list of files, return those in which the `@auto_docstring` decorator is applied to something.

    A file is kept when at least one line containing `decorator` is followed, further down, by a line starting
    with `class ` or with a method definition indented by four spaces. A mention of the decorator with no such
    line after it (for instance in a trailing comment) is ignored.

    Args:
        matching_files: Iterable of file paths to scan.
        decorator: Text to look for in the file content.

    Returns:
        The paths of `matching_files` containing a decorated class or method, in input order.
    """
    # Methods are indented by four spaces; a shorter prefix can never match them with `startswith`
    definition_prefixes = ("class ", "    def")
    auto_docstrings_files = []
    for file_path in matching_files:
        with open(file_path, "r", encoding="utf-8") as f:
            content = f.read()
        if decorator not in content:
            continue
        lines = content.split("\n")
        for line_number, line in enumerate(lines):
            if decorator not in line:
                continue
            # Slicing keeps the scan within the file even when the decorator sits on the very last line
            if any(following.startswith(definition_prefixes) for following in lines[line_number + 1 :]):
                auto_docstrings_files.append(file_path)
                break
    return auto_docstrings_files
def get_auto_docstring_candidate_lines(lines):
    """
    For a file's lines, find the start and end line indices of all `@auto_docstring` candidates.

    The start of a candidate is a line containing `@auto_docstring`; its end is the first following line that
    starts with `class ` or with a method definition indented by four spaces. Decorator mentions that are not
    followed by such a line are skipped.

    Args:
        lines: List of the lines of the file.

    Returns:
        Two lists of equal length: the start indices and the matching end indices.
    """
    # Methods are indented by four spaces; a shorter prefix can never match them with `startswith`
    definition_prefixes = ("class ", "    def")
    line_starts_candidates = []
    line_ends_candidates = []
    for line_number, line in enumerate(lines):
        if "@auto_docstring" not in line:
            continue
        # The range stops at the last line, so a dangling decorator cannot index past the end of the file
        line_end = next(
            (i for i in range(line_number + 1, len(lines)) if lines[i].startswith(definition_prefixes)),
            None,
        )
        if line_end is not None:
            line_starts_candidates.append(line_number)
            line_ends_candidates.append(line_end)
    return line_starts_candidates, line_ends_candidates
def generate_new_docstring_for_signature(
    lines,
    sig_start_line,
    sig_end_line,
    docstring_line,
    arg_indent=" ",
    custom_args_dict=None,
):
    """
    Generalized docstring generator for a function or class signature.

    Args:
        lines: List of lines from the file.
        sig_start_line: Line index where the signature starts.
        sig_end_line: Line index where the signature ends (exclusive).
        docstring_line: Line index where the docstring starts (or None if not present).
        arg_indent: Indentation for missing argument doc entries.
        custom_args_dict: Arguments documented through a `custom_args` variable; arguments found in it are not
            reported as missing. Defaults to no custom arguments.

    Returns:
        A tuple `(new_docstring, sig_end_line, docstring_end, missing_docstring_args, fill_docstring_args)`:
        the regenerated docstring, the signature end index as received, the index of the last docstring line,
        the arguments for which a template entry had to be created, and the already present arguments whose
        description still contains `<fill_docstring>`.
    """
    # `None` instead of a shared `{}` default: a mutable default would be shared between calls
    if custom_args_dict is None:
        custom_args_dict = {}

    missing_docstring_args = []
    fill_docstring_args = []

    # Extract and clean signature: drop trailing comments, merge the lines, cut what follows the last ")"
    signature_content = "".join(line.split("#")[0] for line in lines[sig_start_line:sig_end_line])
    signature_content = "".join(signature_content.split(")")[:-1])
    args_in_signature = re.findall(r"[,(]\s*(\w+)\s*(?=:|=|,|\))", signature_content)
    if "self" in args_in_signature:
        args_in_signature.remove("self")

    # Parse docstring if present
    args_docstring_dict = {}
    remaining_docstring = ""
    docstring_end = sig_end_line - 1
    if docstring_line is not None:
        docstring_end = docstring_line
        # A single `"""` on the opening line means the docstring closes on a later line
        if not lines[docstring_line].count('"""') >= 2:
            docstring_end += 1
            while '"""' not in lines[docstring_end]:
                docstring_end += 1
        docstring_content = lines[docstring_line : docstring_end + 1]
        parsed_docstring, remaining_docstring = parse_docstring("\n".join(docstring_content))
        args_docstring_dict.update(parsed_docstring)

    # Fill missing args
    # Looked up once rather than for every argument (assumes `source_args_doc` is a pure lookup)
    documented_source_args = source_args_doc([ModelArgs, ImageProcessorArgs])
    for arg in args_in_signature:
        if (
            arg not in args_docstring_dict
            and arg not in documented_source_args
            and arg not in custom_args_dict
        ):
            missing_docstring_args.append(arg)
            args_docstring_dict[arg] = {
                "type": "<fill_type>",
                "optional": False,
                "shape": None,
                "description": f"\n{arg_indent} <fill_docstring>",
                "default": None,
                "additional_info": None,
            }

    # Build new docstring
    new_docstring = ""
    if args_docstring_dict or remaining_docstring:
        new_docstring += 'r"""\n'
        for arg, arg_doc in args_docstring_dict.items():
            additional_info = arg_doc["additional_info"] or ""
            custom_arg_description = arg_doc["description"]
            if "<fill_docstring>" in custom_arg_description and arg not in missing_docstring_args:
                fill_docstring_args.append(arg)
            # The closing quotes of the original docstring may be attached to the last description
            if custom_arg_description.endswith('"""'):
                custom_arg_description = "\n".join(custom_arg_description.split("\n")[:-1])
            new_docstring += f"{arg} ({arg_doc['type']}{additional_info}):{custom_arg_description}\n"
        close_docstring = True
        if remaining_docstring:
            # The remaining part already carries the closing quotes: do not add them a second time
            if remaining_docstring.endswith('"""'):
                close_docstring = False
            end_docstring = "\n" if close_docstring else ""
            new_docstring += f"{set_min_indent(remaining_docstring, 0)}{end_docstring}"
        if close_docstring:
            new_docstring += '"""'
        new_docstring = set_min_indent(new_docstring, 8)
    return new_docstring, sig_end_line, docstring_end, missing_docstring_args, fill_docstring_args
def generate_new_docstring_for_function(lines, current_line_end, custom_args_dict):
    """
    Build the docstring of a decorated function by delegating to `generate_new_docstring_for_signature`.

    Args:
        lines: List of lines from the file.
        current_line_end: Index of the line holding the function definition.
        custom_args_dict: Custom argument documentation forwarded to the generic helper.

    Returns:
        The tuple returned by `generate_new_docstring_for_signature`.
    """
    signature_end = _find_sig_line(lines, current_line_end)
    # The docstring, when there is one, opens on the line right after the signature
    if '"""' in lines[signature_end]:
        docstring_start = signature_end
    else:
        docstring_start = None
    return generate_new_docstring_for_signature(
        lines,
        current_line_end,
        signature_end,
        docstring_start,
        arg_indent=" ",
        custom_args_dict=custom_args_dict,
    )
def generate_new_docstring_for_class(lines, current_line_end, custom_args_dict):
    """
    Build the docstring of a decorated class from the signature of its `__init__` method.

    The `__init__` definition is searched on the lines following `current_line_end`; the search stops at the
    next line starting with `class ` or at the end of the file.

    Args:
        lines: List of lines from the file.
        current_line_end: Index of the line holding the class definition.
        custom_args_dict: Custom argument documentation forwarded to the generic helper.

    Returns:
        A tuple `(new_docstring, init_line, init_signature_end, init_docstring_end, missing_docstring_args,
        fill_docstring_args)`. When no `__init__` is found, this is `("", None, None, None, [], [])`.
    """
    init_line = None
    for index in range(current_line_end + 1, len(lines)):
        current = lines[index]
        if " def __init__" in current:
            init_line = index
            break
        # Reaching another class means this one has no `__init__` of its own
        if current.startswith("class "):
            break

    if init_line is None:
        return "", None, None, None, [], []

    init_signature_end = _find_sig_line(lines, init_line)
    if '"""' in lines[init_signature_end]:
        docstring_start = init_signature_end
    else:
        docstring_start = None

    generated = generate_new_docstring_for_signature(
        lines,
        init_line,
        init_signature_end,
        docstring_start,
        arg_indent="",
        custom_args_dict=custom_args_dict,
    )
    new_docstring, _, init_docstring_end, missing_args, args_to_fill = generated
    return (
        new_docstring,
        init_line,
        init_signature_end,
        init_docstring_end,
        missing_args,
        args_to_fill,
    )
def find_custom_args_with_details(file_content: str, custom_args_var_name: str) -> Optional[str]:
    """
    Find the given custom args variable in the file content and return its content.

    Args:
        file_content: The string content of the Python file.
        custom_args_var_name: The name of the custom args variable.

    Returns:
        The text between the triple quotes of the first assignment to `custom_args_var_name`, stripped of
        leading and trailing whitespace, or `None` when no such triple-quoted assignment is found.
    """
    # Escape the variable name to handle any special regex characters it might contain
    escaped_variable_name = re.escape(custom_args_var_name)
    # The pattern looks for:
    # ^\s*                      : Start of a line with optional leading whitespace.
    # ({escaped_variable_name}) : The exact variable name (group 1).
    # \s*=\s*                   : An equals sign, surrounded by optional whitespace.
    # (r?\"\"\")                : The opening triple quotes, raw or normal string (group 2).
    # (.*?)                     : The content, non-greedy (group 3).
    # (\"\"\")                  : The closing triple quotes (group 4).
    regex_pattern = rf"^\s*({escaped_variable_name})\s*=\s*(r?\"\"\")(.*?)(\"\"\")"
    # MULTILINE lets `^` match at every line start, DOTALL lets the content span several lines
    match = re.search(regex_pattern, file_content, re.MULTILINE | re.DOTALL)
    if match is None:
        return None
    return match.group(3).strip()
def update_file_with_new_docstrings(
candidate_file, lines, line_starts_candidates, line_ends_candidates, overwrite=False
):
"""
For a given file, update the docstrings for all @auto_docstring candidates and write the new content.

Args:
candidate_file: Path of the file being processed. It is only opened (for writing) when `overwrite` is True.
lines: The content of the file, split into lines.
line_starts_candidates: Indices of the lines containing an `@auto_docstring` decorator.
line_ends_candidates: Index of the `class`/`def` line associated with each decorator, in the same order.
overwrite: Whether to write the regenerated content back to `candidate_file`.

Returns:
A tuple `(missing_docstring_args_warnings, fill_docstring_args_warnings)`. Each is a list of strings of the
form ` - {arg} line {current_line_end}`, one per reported argument.
"""
# NOTE(review): both candidate lists are indexed at [0] below, so they must not be empty.
# Everything located before the first decorated definition is copied unchanged.
content_base_file_new_lines = lines[: line_ends_candidates[0]]
current_line_start = line_starts_candidates[0]
current_line_end = line_ends_candidates[0]
# `index` designates the NEXT candidate; the current one is tracked by `current_line_start`/`current_line_end`.
index = 1
missing_docstring_args_warnings = []
fill_docstring_args_warnings = []
while index <= len(line_starts_candidates):
custom_args_dict = {}
# Text between the decorator line and the definition line, i.e. the decorator call and its arguments
auto_docstring_signature_content = "".join(lines[current_line_start:current_line_end])
match = re.findall(r"custom_args=(\w+)", auto_docstring_signature_content)
if match:
# Only the first `custom_args=` occurrence is used; its variable is looked up in the whole file
custom_args_var_name = match[0]
custom_args_var_content = find_custom_args_with_details("\n".join(lines), custom_args_var_name)
if custom_args_var_content:
custom_args_dict, _ = parse_docstring(custom_args_var_content)
new_docstring = ""
found_init_method = False
# NOTE(review): `missing_docstring_args` and `fill_docstring_args` are only assigned in the two branches below;
# a candidate line matching neither would reuse the previous iteration's values (or be unbound on the first one).
# Function
if " def" in lines[current_line_end]:
new_docstring, sig_line_end, docstring_end, missing_docstring_args, fill_docstring_args = (
generate_new_docstring_for_function(lines, current_line_end, custom_args_dict)
)
# Class
elif "class " in lines[current_line_end]:
(
new_docstring,
init_method_line,
init_method_sig_line_end,
init_method_docstring_end,
missing_docstring_args,
fill_docstring_args,
) = generate_new_docstring_for_class(lines, current_line_end, custom_args_dict)
# `generate_new_docstring_for_class` returns None for this index when the class has no `__init__`
found_init_method = init_method_line is not None
# Add warnings if needed
if missing_docstring_args:
for arg in missing_docstring_args:
missing_docstring_args_warnings.append(f" - {arg} line {current_line_end}")
if fill_docstring_args:
for arg in fill_docstring_args:
fill_docstring_args_warnings.append(f" - {arg} line {current_line_end}")
# Write new lines
# Nothing is emitted when the next candidate does not end after the current one
if index >= len(line_ends_candidates) or line_ends_candidates[index] > current_line_end:
if " def" in lines[current_line_end]:
# Function: signature lines, then the new docstring, then the lines after the old docstring
content_base_file_new_lines += lines[current_line_end:sig_line_end]
if new_docstring != "":
content_base_file_new_lines += new_docstring.split("\n")
if index < len(line_ends_candidates):
content_base_file_new_lines += lines[docstring_end + 1 : line_ends_candidates[index]]
else:
content_base_file_new_lines += lines[docstring_end + 1 :]
elif found_init_method:
# Class with `__init__`: lines up to the end of the `__init__` signature, then the new docstring
content_base_file_new_lines += lines[current_line_end:init_method_sig_line_end]
if new_docstring != "":
content_base_file_new_lines += new_docstring.split("\n")
if index < len(line_ends_candidates):
content_base_file_new_lines += lines[init_method_docstring_end + 1 : line_ends_candidates[index]]
else:
content_base_file_new_lines += lines[init_method_docstring_end + 1 :]
elif index < len(line_ends_candidates):
# No `__init__` was found: copy the lines unchanged up to the next candidate
content_base_file_new_lines += lines[current_line_end : line_ends_candidates[index]]
else:
# Last candidate without `__init__`: copy the rest of the file unchanged
content_base_file_new_lines += lines[current_line_end:]
# Move on to the next candidate, if any
if index < len(line_ends_candidates):
current_line_end = line_ends_candidates[index]
current_line_start = line_starts_candidates[index]
index += 1
content_base_file_new = "\n".join(content_base_file_new_lines)
# The file is only modified on explicit request; warnings are returned in both modes
if overwrite:
with open(candidate_file, "w", encoding="utf-8") as f:
f.write(content_base_file_new)
return missing_docstring_args_warnings, fill_docstring_args_warnings
def check_auto_docstrings(overwrite: bool = False, check_all: bool = False):
    """
    Check the docstrings of the objects decorated with `@auto_docstring`.

    The relevant files are collected, filtered down to those using the decorator, and each of them has its
    docstrings regenerated. Arguments whose documentation is missing or still has to be filled are printed.

    Args:
        overwrite: Whether to write the regenerated docstrings back to the files.
        check_all: Whether to check all files instead of only the ones in the diff.
    """
    # 1. Find all model files to check
    files_to_check = find_matching_model_files(check_all)
    if files_to_check is None:
        return

    # 2. Keep the files that contain the @auto_docstring decorator
    decorated_files = find_files_with_auto_docstring(files_to_check)

    # 3. For each file, update docstrings for all candidates
    for candidate_file in decorated_files:
        with open(candidate_file, "r", encoding="utf-8") as f:
            lines = f.read().split("\n")
        starts, ends = get_auto_docstring_candidate_lines(lines)
        missing_warnings, fill_warnings = update_file_with_new_docstrings(
            candidate_file, lines, starts, ends, overwrite=overwrite
        )
        if missing_warnings:
            if not overwrite:
                print(
                    "Some docstrings are missing. Run `make fix-copies` or `python utils/check_docstrings.py --fix_and_overwrite` to generate the docstring templates where needed."
                )
            print(f"🚨 Missing docstring for the following arguments in {candidate_file}:")
            print(*missing_warnings, sep="\n")
        if fill_warnings:
            print(f"🚨 Docstring needs to be filled for the following arguments in {candidate_file}:")
            print(*fill_warnings, sep="\n")
def check_docstrings(overwrite: bool = False, check_all: bool = False):
"""
Check docstrings of all public objects that are callables and are documented. By default, only checks the diff.
@@ -1017,6 +1418,9 @@ def check_docstrings(overwrite: bool = False, check_all: bool = False):
hard_failures.append(name)
continue
if old_doc != new_doc:
print("name", name)
print("old_doc", old_doc)
print("new_doc", new_doc)
if overwrite:
fix_docstring(obj, old_doc, new_doc)
else:
@@ -1059,5 +1463,5 @@ if __name__ == "__main__":
"--check_all", action="store_true", help="Whether to check all files. By default, only checks the diff"
)
args = parser.parse_args()
check_auto_docstrings(overwrite=args.fix_and_overwrite, check_all=args.check_all)
check_docstrings(overwrite=args.fix_and_overwrite, check_all=args.check_all)

View File

@@ -258,7 +258,7 @@ def get_docstring_indent(docstring):
return 0
def is_full_docstring(new_docstring: str) -> bool:
def is_full_docstring(original_docstring: str, new_docstring: str, original_level: int) -> bool:
"""Check if `new_docstring` is a full docstring, or if it is only part of a docstring that should then
be merged with the existing old one.
"""
@@ -267,6 +267,17 @@ def is_full_docstring(new_docstring: str) -> bool:
# The docstring contains Args definition, so it is self-contained
if re.search(r"\n\s*Args:\n", new_docstring):
return True
elif re.search(r"\n\s*Args:\n", original_docstring):
return False
# Check if the docstring contains args docstring (meaning it is self contained):
param_pattern = re.compile(
# |--- Group 1 ---|| Group 2 ||- Group 3 -||---------- Group 4 ----------|
rf"^\s{{0,{original_level}}}(\w+)\s*\(\s*([^, \)]*)(\s*.*?)\s*\)\s*:\s*((?:(?!\n^\s{{0,{original_level}}}\w+\s*\().)*)",
re.DOTALL | re.MULTILINE,
)
match_object = param_pattern.search(new_docstring)
if match_object is not None:
return True
# If it contains Returns, but starts with text indented with an additional 4 spaces before, it is self-contained
# (this is the scenario when using `@add_start_docstrings_to_model_forward`, but adding more args to docstring)
match_object = re.search(r"\n([^\S\n]*)Returns:\n", new_docstring)
@@ -280,7 +291,7 @@ def is_full_docstring(new_docstring: str) -> bool:
def merge_docstrings(original_docstring, updated_docstring):
original_level = get_docstring_indent(original_docstring)
if not is_full_docstring(updated_docstring):
if not is_full_docstring(original_docstring, updated_docstring, original_level):
# Split the docstring at the example section, assuming `"""` is used to define the docstring
parts = original_docstring.split("```")
if "```" in updated_docstring and len(parts) > 1:
@@ -291,13 +302,22 @@ def merge_docstrings(original_docstring, updated_docstring):
parts[1] = new_parts[1]
updated_docstring = "".join(
[
parts[0].rstrip(" \n") + new_parts[0],
f"\n{original_level * ' '}```",
parts[1],
"```",
parts[2],
]
)
docstring_opening, original_start_docstring = parts[0].rstrip(" \n").split('"""')[:2]
new_start_docstring = new_parts[0].rstrip(" \n")
docstring_opening += '"""'
if new_start_docstring.startswith(original_start_docstring):
updated_docstring = new_start_docstring + "\n" + updated_docstring
elif original_start_docstring.endswith(new_start_docstring):
updated_docstring = original_start_docstring + "\n" + updated_docstring
else:
updated_docstring = original_start_docstring + "\n" + new_start_docstring + "\n" + updated_docstring
updated_docstring = docstring_opening + updated_docstring
elif updated_docstring not in original_docstring:
# add tabulation if we are at the lowest level.
if re.search(r"\n\s*.*\(.*\)\:\n\s*\w", updated_docstring):