[AutoDocstring] Based on inspect parsing of the signature (#33771)
* delete common docstring * nit * updates * push * fixup * move stuff around fixup * no need for dataclas * damn nice modular * add auto class docstring * style * modular update * import autodocstring * fixup * maybe add original doc! * more cleanup * remove class do cas well * update * nits * more celanup * fix * wups * small check * updatez * some fixes * fix doc * update * nits * try? * nit * some updates * a little bit better * where ever we did not have help we are not really adding it! * revert llama config * small fixes and small tests * test * fixup * more fix-copies * updates * updates * fix doc building * style * small fixes * nits * fix-copies * fix merge issues faster * fix merge conf * nits jamba * ? * working autodoc for model class and forward except returns and example * support return section and unpack kwargs description * nits and cleanup * fix-copies * fix-copies * nits * Add support for llava-like models * fixup * add class args subset support * add examples inferred from automodel/pipelines * update ruff * autodocstring for Aria, Albert + fixups * Fix empty return blocks * fix copies * fix copies * add autodoc for all fast image processors + align, altclip * fix copies * add auto_doc for audio_spectrogram, auto_former, bark, bamba * Drastically improve speed + add bart beit bert * add autodoc to all bert-like models * Fix broken doc * fix copies * fix auto_docstring after merge * add autodoc to models * add models * add models * add models and improve support for optional, and custom shape in args docstring * update fast image processors * refactor auto_method_docstring in args_doc * add models and fix docstring parsing * add models * add models * remove debugging * add models * add fix_auto_docstrings and improve args_docs * add support for additional_info in args docstring * refactor (almost) all models * fix check docstring * fix -copies * fill in all missing docstrings * fix copies * fix qwen3 moe docstring * add documentation * add back labels * update docs and fix can_return_tuple in modular files * fix LongformerForMaskedLM docstring * add auto_docstring to _toctree * remove auto_docstring tests temporarily * fix copyrights new files * fix can_return_tuple granite hybrid * fix fast beit * Fix empty config doc * add support for COMMON_CUSTOM_ARGS in check_docstrings and add missing models * fix code block not closed flava * fix can_return_tuple sam hq * Fix Flaubert dataclass --------- Co-authored-by: yonigozlan <yoni.gozlan@huggingface.co> Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
This commit is contained in:
@@ -36,8 +36,10 @@ like argument descriptions).
|
||||
import argparse
|
||||
import ast
|
||||
import enum
|
||||
import glob
|
||||
import inspect
|
||||
import operator as op
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
@@ -46,6 +48,7 @@ from check_repo import ignore_undocumented
|
||||
from git import Repo
|
||||
|
||||
from transformers.utils import direct_transformers_import
|
||||
from transformers.utils.args_doc import ImageProcessorArgs, ModelArgs, parse_docstring, set_min_indent, source_args_doc
|
||||
|
||||
|
||||
PATH_TO_REPO = Path(__file__).parent.parent.resolve()
|
||||
@@ -959,6 +962,404 @@ def fix_docstring(obj: Any, old_doc_args: str, new_doc_args: str):
|
||||
f.write("\n".join(lines))
|
||||
|
||||
|
||||
def _find_sig_line(lines, line_end):
|
||||
parenthesis_count = 0
|
||||
sig_line_end = line_end
|
||||
found_sig = False
|
||||
while not found_sig:
|
||||
for char in lines[sig_line_end]:
|
||||
if char == "(":
|
||||
parenthesis_count += 1
|
||||
elif char == ")":
|
||||
parenthesis_count -= 1
|
||||
if parenthesis_count == 0:
|
||||
found_sig = True
|
||||
break
|
||||
sig_line_end += 1
|
||||
return sig_line_end
|
||||
|
||||
|
||||
def find_matching_model_files(check_all: bool = False):
|
||||
"""
|
||||
Find all model files in the transformers repo that should be checked for @auto_docstring,
|
||||
excluding files with certain substrings.
|
||||
Returns:
|
||||
List of file paths.
|
||||
"""
|
||||
module_diff_files = None
|
||||
if not check_all:
|
||||
module_diff_files = set()
|
||||
repo = Repo(PATH_TO_REPO)
|
||||
# Diff from index to unstaged files
|
||||
for modified_file_diff in repo.index.diff(None):
|
||||
if modified_file_diff.a_path.startswith("src/transformers"):
|
||||
module_diff_files.add(os.path.join(PATH_TO_REPO, modified_file_diff.a_path))
|
||||
# Diff from index to `main`
|
||||
for modified_file_diff in repo.index.diff(repo.refs.main.commit):
|
||||
if modified_file_diff.a_path.startswith("src/transformers"):
|
||||
module_diff_files.add(os.path.join(PATH_TO_REPO, modified_file_diff.a_path))
|
||||
# quick escape route: if there are no module files in the diff, skip this check
|
||||
if len(module_diff_files) == 0:
|
||||
return None
|
||||
|
||||
modeling_glob_pattern = os.path.join(PATH_TO_TRANSFORMERS, "models/**/modeling_**")
|
||||
potential_files = glob.glob(modeling_glob_pattern)
|
||||
image_processing_glob_pattern = os.path.join(PATH_TO_TRANSFORMERS, "models/**/image_processing_*_fast.py")
|
||||
potential_files += glob.glob(image_processing_glob_pattern)
|
||||
exclude_substrings = ["modeling_tf_", "modeling_flax_"]
|
||||
matching_files = []
|
||||
for file_path in potential_files:
|
||||
if os.path.isfile(file_path):
|
||||
filename = os.path.basename(file_path)
|
||||
is_excluded = any(exclude in filename for exclude in exclude_substrings)
|
||||
if not is_excluded:
|
||||
matching_files.append(file_path)
|
||||
if not check_all:
|
||||
# intersect with module_diff_files
|
||||
matching_files = sorted([file for file in matching_files if file in module_diff_files])
|
||||
|
||||
print(" Checking auto_docstrings in the following files:" + "\n - " + "\n - ".join(matching_files))
|
||||
|
||||
return matching_files
|
||||
|
||||
|
||||
def find_files_with_auto_docstring(matching_files, decorator="@auto_docstring"):
|
||||
"""
|
||||
From a list of files, return those that contain the @auto_docstring decorator.
|
||||
"""
|
||||
auto_docstrings_files = []
|
||||
for file_path in matching_files:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
content_base_file = f.read()
|
||||
if decorator in content_base_file:
|
||||
lines = content_base_file.split("\n")
|
||||
line_numbers = [i for i, line in enumerate(lines) if decorator in line]
|
||||
for line_number in line_numbers:
|
||||
line_end = line_number
|
||||
end_patterns = ["class ", " def"]
|
||||
stop_condition = False
|
||||
while line_end < len(lines) and not stop_condition:
|
||||
line_end += 1
|
||||
stop_condition = any(lines[line_end].startswith(end_pattern) for end_pattern in end_patterns)
|
||||
candidate_patterns = ["class ", " def"]
|
||||
candidate = any(
|
||||
lines[line_end].startswith(candidate_pattern) for candidate_pattern in candidate_patterns
|
||||
)
|
||||
if stop_condition and candidate:
|
||||
auto_docstrings_files.append(file_path)
|
||||
break
|
||||
return auto_docstrings_files
|
||||
|
||||
|
||||
def get_auto_docstring_candidate_lines(lines):
|
||||
"""
|
||||
For a file's lines, find the start and end line indices of all @auto_docstring candidates.
|
||||
Returns two lists: starts and ends.
|
||||
"""
|
||||
line_numbers = [i for i, line in enumerate(lines) if "@auto_docstring" in line]
|
||||
line_starts_candidates = []
|
||||
line_ends_candidates = []
|
||||
for line_number in line_numbers:
|
||||
line_end = line_number
|
||||
end_patterns = ["class ", " def"]
|
||||
stop_condition = False
|
||||
while line_end < len(lines) and not stop_condition:
|
||||
line_end += 1
|
||||
stop_condition = any(lines[line_end].startswith(end_pattern) for end_pattern in end_patterns)
|
||||
candidate_patterns = ["class ", " def"]
|
||||
candidate = any(lines[line_end].startswith(candidate_pattern) for candidate_pattern in candidate_patterns)
|
||||
if stop_condition and candidate:
|
||||
line_ends_candidates.append(line_end)
|
||||
line_starts_candidates.append(line_number)
|
||||
return line_starts_candidates, line_ends_candidates
|
||||
|
||||
|
||||
def generate_new_docstring_for_signature(
|
||||
lines,
|
||||
sig_start_line,
|
||||
sig_end_line,
|
||||
docstring_line,
|
||||
arg_indent=" ",
|
||||
custom_args_dict={},
|
||||
):
|
||||
"""
|
||||
Generalized docstring generator for a function or class signature.
|
||||
Args:
|
||||
lines: List of lines from the file.
|
||||
sig_start_line: Line index where the signature starts.
|
||||
sig_end_line: Line index where the signature ends.
|
||||
docstring_line: Line index where the docstring starts (or None if not present).
|
||||
arg_indent: Indentation for missing argument doc entries.
|
||||
Returns:
|
||||
new_docstring, sig_end_line, docstring_end (last docstring line index)
|
||||
"""
|
||||
# Extract and clean signature
|
||||
missing_docstring_args = []
|
||||
fill_docstring_args = []
|
||||
|
||||
signature_content = lines[sig_start_line:sig_end_line]
|
||||
signature_content = [line.split("#")[0] for line in signature_content]
|
||||
signature_content = "".join(signature_content)
|
||||
signature_content = "".join(signature_content.split(")")[:-1])
|
||||
args_in_signature = re.findall(r"[,(]\s*(\w+)\s*(?=:|=|,|\))", signature_content)
|
||||
if "self" in args_in_signature:
|
||||
args_in_signature.remove("self")
|
||||
# Parse docstring if present
|
||||
args_docstring_dict = {}
|
||||
remaining_docstring = ""
|
||||
docstring_end = sig_end_line - 1
|
||||
if docstring_line is not None:
|
||||
docstring_end = docstring_line
|
||||
if not lines[docstring_line].count('"""') >= 2:
|
||||
docstring_end += 1
|
||||
while '"""' not in lines[docstring_end]:
|
||||
docstring_end += 1
|
||||
docstring_content = lines[docstring_line : docstring_end + 1]
|
||||
parsed_docstring, remaining_docstring = parse_docstring("\n".join(docstring_content))
|
||||
args_docstring_dict.update(parsed_docstring)
|
||||
# Fill missing args
|
||||
for arg in args_in_signature:
|
||||
if (
|
||||
arg not in args_docstring_dict
|
||||
and arg not in source_args_doc([ModelArgs, ImageProcessorArgs])
|
||||
and arg not in custom_args_dict
|
||||
):
|
||||
missing_docstring_args.append(arg)
|
||||
args_docstring_dict[arg] = {
|
||||
"type": "<fill_type>",
|
||||
"optional": False,
|
||||
"shape": None,
|
||||
"description": f"\n{arg_indent} <fill_docstring>",
|
||||
"default": None,
|
||||
"additional_info": None,
|
||||
}
|
||||
# Build new docstring
|
||||
new_docstring = ""
|
||||
if len(args_docstring_dict) > 0 or remaining_docstring:
|
||||
new_docstring += 'r"""\n'
|
||||
for arg in args_docstring_dict:
|
||||
additional_info = args_docstring_dict[arg]["additional_info"] or ""
|
||||
custom_arg_description = args_docstring_dict[arg]["description"]
|
||||
if "<fill_docstring>" in custom_arg_description and arg not in missing_docstring_args:
|
||||
fill_docstring_args.append(arg)
|
||||
if custom_arg_description.endswith('"""'):
|
||||
custom_arg_description = "\n".join(custom_arg_description.split("\n")[:-1])
|
||||
new_docstring += f"{arg} ({args_docstring_dict[arg]['type']}{additional_info}):{custom_arg_description}\n"
|
||||
close_docstring = True
|
||||
if remaining_docstring:
|
||||
if remaining_docstring.endswith('"""'):
|
||||
close_docstring = False
|
||||
end_docstring = "\n" if close_docstring else ""
|
||||
new_docstring += f"{set_min_indent(remaining_docstring, 0)}{end_docstring}"
|
||||
if close_docstring:
|
||||
new_docstring += '"""'
|
||||
new_docstring = set_min_indent(new_docstring, 8)
|
||||
return new_docstring, sig_end_line, docstring_end, missing_docstring_args, fill_docstring_args
|
||||
|
||||
|
||||
def generate_new_docstring_for_function(lines, current_line_end, custom_args_dict):
|
||||
"""
|
||||
Wrapper for function docstring generation using the generalized helper.
|
||||
"""
|
||||
sig_line_end = _find_sig_line(lines, current_line_end)
|
||||
docstring_line = sig_line_end if '"""' in lines[sig_line_end] else None
|
||||
return generate_new_docstring_for_signature(
|
||||
lines,
|
||||
current_line_end,
|
||||
sig_line_end,
|
||||
docstring_line,
|
||||
arg_indent=" ",
|
||||
custom_args_dict=custom_args_dict,
|
||||
)
|
||||
|
||||
|
||||
def generate_new_docstring_for_class(lines, current_line_end, custom_args_dict):
|
||||
"""
|
||||
Wrapper for class docstring generation (via __init__) using the generalized helper.
|
||||
Returns the new docstring and relevant signature/docstring indices.
|
||||
"""
|
||||
init_method_line = current_line_end
|
||||
found_init_method = False
|
||||
while init_method_line < len(lines) - 1 and not found_init_method:
|
||||
init_method_line += 1
|
||||
if " def __init__" in lines[init_method_line]:
|
||||
found_init_method = True
|
||||
elif lines[init_method_line].startswith("class "):
|
||||
break
|
||||
if not found_init_method:
|
||||
return "", None, None, None, [], []
|
||||
init_method_sig_line_end = _find_sig_line(lines, init_method_line)
|
||||
docstring_line = init_method_sig_line_end if '"""' in lines[init_method_sig_line_end] else None
|
||||
new_docstring, _, init_method_docstring_end, missing_docstring_args, fill_docstring_args = (
|
||||
generate_new_docstring_for_signature(
|
||||
lines,
|
||||
init_method_line,
|
||||
init_method_sig_line_end,
|
||||
docstring_line,
|
||||
arg_indent="",
|
||||
custom_args_dict=custom_args_dict,
|
||||
)
|
||||
)
|
||||
return (
|
||||
new_docstring,
|
||||
init_method_line,
|
||||
init_method_sig_line_end,
|
||||
init_method_docstring_end,
|
||||
missing_docstring_args,
|
||||
fill_docstring_args,
|
||||
)
|
||||
|
||||
|
||||
def find_custom_args_with_details(file_content: str, custom_args_var_name: str) -> list[dict]:
|
||||
"""
|
||||
Find the given custom args variable in the file content and return its content.
|
||||
|
||||
Args:
|
||||
file_content: The string content of the Python file.
|
||||
custom_args_var_name: The name of the custom args variable.
|
||||
"""
|
||||
# Escape the variable_name to handle any special regex characters it might contain
|
||||
escaped_variable_name = re.escape(custom_args_var_name)
|
||||
|
||||
# Construct the regex pattern dynamically with the specific variable name
|
||||
# This regex looks for:
|
||||
# ^\s* : Start of a line with optional leading whitespace.
|
||||
# ({escaped_variable_name}) : Capture the exact variable name.
|
||||
# \s*=\s* : An equals sign, surrounded by optional whitespace.
|
||||
# (r?\"\"\") : Capture the opening triple quotes (raw or normal string).
|
||||
# (.*?) : Capture the content (non-greedy).
|
||||
# (\"\"\") : Match the closing triple quotes.
|
||||
regex_pattern = rf"^\s*({escaped_variable_name})\s*=\s*(r?\"\"\")(.*?)(\"\"\")"
|
||||
|
||||
flags = re.MULTILINE | re.DOTALL
|
||||
|
||||
# Use re.search to find the first match
|
||||
match = re.search(regex_pattern, file_content, flags)
|
||||
|
||||
if match:
|
||||
# match.group(1) will be the variable_name itself
|
||||
# match.group(3) will be the content inside the triple quotes
|
||||
content = match.group(3).strip()
|
||||
return content
|
||||
return None
|
||||
|
||||
|
||||
def update_file_with_new_docstrings(
|
||||
candidate_file, lines, line_starts_candidates, line_ends_candidates, overwrite=False
|
||||
):
|
||||
"""
|
||||
For a given file, update the docstrings for all @auto_docstring candidates and write the new content.
|
||||
"""
|
||||
content_base_file_new_lines = lines[: line_ends_candidates[0]]
|
||||
current_line_start = line_starts_candidates[0]
|
||||
current_line_end = line_ends_candidates[0]
|
||||
index = 1
|
||||
missing_docstring_args_warnings = []
|
||||
|
||||
fill_docstring_args_warnings = []
|
||||
while index <= len(line_starts_candidates):
|
||||
custom_args_dict = {}
|
||||
auto_docstring_signature_content = "".join(lines[current_line_start:current_line_end])
|
||||
match = re.findall(r"custom_args=(\w+)", auto_docstring_signature_content)
|
||||
if match:
|
||||
custom_args_var_name = match[0]
|
||||
custom_args_var_content = find_custom_args_with_details("\n".join(lines), custom_args_var_name)
|
||||
if custom_args_var_content:
|
||||
custom_args_dict, _ = parse_docstring(custom_args_var_content)
|
||||
new_docstring = ""
|
||||
found_init_method = False
|
||||
# Function
|
||||
if " def" in lines[current_line_end]:
|
||||
new_docstring, sig_line_end, docstring_end, missing_docstring_args, fill_docstring_args = (
|
||||
generate_new_docstring_for_function(lines, current_line_end, custom_args_dict)
|
||||
)
|
||||
# Class
|
||||
elif "class " in lines[current_line_end]:
|
||||
(
|
||||
new_docstring,
|
||||
init_method_line,
|
||||
init_method_sig_line_end,
|
||||
init_method_docstring_end,
|
||||
missing_docstring_args,
|
||||
fill_docstring_args,
|
||||
) = generate_new_docstring_for_class(lines, current_line_end, custom_args_dict)
|
||||
found_init_method = init_method_line is not None
|
||||
# Add warnings if needed
|
||||
if missing_docstring_args:
|
||||
for arg in missing_docstring_args:
|
||||
missing_docstring_args_warnings.append(f" - {arg} line {current_line_end}")
|
||||
if fill_docstring_args:
|
||||
for arg in fill_docstring_args:
|
||||
fill_docstring_args_warnings.append(f" - {arg} line {current_line_end}")
|
||||
|
||||
# Write new lines
|
||||
if index >= len(line_ends_candidates) or line_ends_candidates[index] > current_line_end:
|
||||
if " def" in lines[current_line_end]:
|
||||
content_base_file_new_lines += lines[current_line_end:sig_line_end]
|
||||
if new_docstring != "":
|
||||
content_base_file_new_lines += new_docstring.split("\n")
|
||||
if index < len(line_ends_candidates):
|
||||
content_base_file_new_lines += lines[docstring_end + 1 : line_ends_candidates[index]]
|
||||
else:
|
||||
content_base_file_new_lines += lines[docstring_end + 1 :]
|
||||
elif found_init_method:
|
||||
content_base_file_new_lines += lines[current_line_end:init_method_sig_line_end]
|
||||
if new_docstring != "":
|
||||
content_base_file_new_lines += new_docstring.split("\n")
|
||||
if index < len(line_ends_candidates):
|
||||
content_base_file_new_lines += lines[init_method_docstring_end + 1 : line_ends_candidates[index]]
|
||||
else:
|
||||
content_base_file_new_lines += lines[init_method_docstring_end + 1 :]
|
||||
elif index < len(line_ends_candidates):
|
||||
content_base_file_new_lines += lines[current_line_end : line_ends_candidates[index]]
|
||||
else:
|
||||
content_base_file_new_lines += lines[current_line_end:]
|
||||
if index < len(line_ends_candidates):
|
||||
current_line_end = line_ends_candidates[index]
|
||||
current_line_start = line_starts_candidates[index]
|
||||
index += 1
|
||||
content_base_file_new = "\n".join(content_base_file_new_lines)
|
||||
if overwrite:
|
||||
with open(candidate_file, "w", encoding="utf-8") as f:
|
||||
f.write(content_base_file_new)
|
||||
|
||||
return missing_docstring_args_warnings, fill_docstring_args_warnings
|
||||
|
||||
|
||||
def check_auto_docstrings(overwrite: bool = False, check_all: bool = False):
|
||||
"""
|
||||
Check docstrings of all public objects that are decorated with `@auto_docstrings`.
|
||||
This function orchestrates the process by finding relevant files, scanning for decorators,
|
||||
generating new docstrings, and updating files as needed.
|
||||
"""
|
||||
# 1. Find all model files to check
|
||||
matching_files = find_matching_model_files(check_all)
|
||||
if matching_files is None:
|
||||
return
|
||||
# 2. Find files that contain the @auto_docstring decorator
|
||||
auto_docstrings_files = find_files_with_auto_docstring(matching_files)
|
||||
# 3. For each file, update docstrings for all candidates
|
||||
for candidate_file in auto_docstrings_files:
|
||||
with open(candidate_file, "r", encoding="utf-8") as f:
|
||||
lines = f.read().split("\n")
|
||||
line_starts_candidates, line_ends_candidates = get_auto_docstring_candidate_lines(lines)
|
||||
missing_docstring_args_warnings, fill_docstring_args_warnings = update_file_with_new_docstrings(
|
||||
candidate_file, lines, line_starts_candidates, line_ends_candidates, overwrite=overwrite
|
||||
)
|
||||
if missing_docstring_args_warnings:
|
||||
if not overwrite:
|
||||
print(
|
||||
"Some docstrings are missing. Run `make fix-copies` or `python utils/check_docstrings.py --fix_and_overwrite` to generate the docstring templates where needed."
|
||||
)
|
||||
print(f"🚨 Missing docstring for the following arguments in {candidate_file}:")
|
||||
for warning in missing_docstring_args_warnings:
|
||||
print(warning)
|
||||
if fill_docstring_args_warnings:
|
||||
print(f"🚨 Docstring needs to be filled for the following arguments in {candidate_file}:")
|
||||
for warning in fill_docstring_args_warnings:
|
||||
print(warning)
|
||||
|
||||
|
||||
def check_docstrings(overwrite: bool = False, check_all: bool = False):
|
||||
"""
|
||||
Check docstrings of all public objects that are callables and are documented. By default, only checks the diff.
|
||||
@@ -1017,6 +1418,9 @@ def check_docstrings(overwrite: bool = False, check_all: bool = False):
|
||||
hard_failures.append(name)
|
||||
continue
|
||||
if old_doc != new_doc:
|
||||
print("name", name)
|
||||
print("old_doc", old_doc)
|
||||
print("new_doc", new_doc)
|
||||
if overwrite:
|
||||
fix_docstring(obj, old_doc, new_doc)
|
||||
else:
|
||||
@@ -1059,5 +1463,5 @@ if __name__ == "__main__":
|
||||
"--check_all", action="store_true", help="Whether to check all files. By default, only checks the diff"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
check_auto_docstrings(overwrite=args.fix_and_overwrite, check_all=args.check_all)
|
||||
check_docstrings(overwrite=args.fix_and_overwrite, check_all=args.check_all)
|
||||
|
||||
@@ -258,7 +258,7 @@ def get_docstring_indent(docstring):
|
||||
return 0
|
||||
|
||||
|
||||
def is_full_docstring(new_docstring: str) -> bool:
|
||||
def is_full_docstring(original_docstring: str, new_docstring: str, original_level: int) -> bool:
|
||||
"""Check if `new_docstring` is a full docstring, or if it is only part of a docstring that should then
|
||||
be merged with the existing old one.
|
||||
"""
|
||||
@@ -267,6 +267,17 @@ def is_full_docstring(new_docstring: str) -> bool:
|
||||
# The docstring contains Args definition, so it is self-contained
|
||||
if re.search(r"\n\s*Args:\n", new_docstring):
|
||||
return True
|
||||
elif re.search(r"\n\s*Args:\n", original_docstring):
|
||||
return False
|
||||
# Check if the docstring contains args docstring (meaning it is self contained):
|
||||
param_pattern = re.compile(
|
||||
# |--- Group 1 ---|| Group 2 ||- Group 3 -||---------- Group 4 ----------|
|
||||
rf"^\s{{0,{original_level}}}(\w+)\s*\(\s*([^, \)]*)(\s*.*?)\s*\)\s*:\s*((?:(?!\n^\s{{0,{original_level}}}\w+\s*\().)*)",
|
||||
re.DOTALL | re.MULTILINE,
|
||||
)
|
||||
match_object = param_pattern.search(new_docstring)
|
||||
if match_object is not None:
|
||||
return True
|
||||
# If it contains Returns, but starts with text indented with an additional 4 spaces before, it is self-contained
|
||||
# (this is the scenario when using `@add_start_docstrings_to_model_forward`, but adding more args to docstring)
|
||||
match_object = re.search(r"\n([^\S\n]*)Returns:\n", new_docstring)
|
||||
@@ -280,7 +291,7 @@ def is_full_docstring(new_docstring: str) -> bool:
|
||||
|
||||
def merge_docstrings(original_docstring, updated_docstring):
|
||||
original_level = get_docstring_indent(original_docstring)
|
||||
if not is_full_docstring(updated_docstring):
|
||||
if not is_full_docstring(original_docstring, updated_docstring, original_level):
|
||||
# Split the docstring at the example section, assuming `"""` is used to define the docstring
|
||||
parts = original_docstring.split("```")
|
||||
if "```" in updated_docstring and len(parts) > 1:
|
||||
@@ -291,13 +302,22 @@ def merge_docstrings(original_docstring, updated_docstring):
|
||||
parts[1] = new_parts[1]
|
||||
updated_docstring = "".join(
|
||||
[
|
||||
parts[0].rstrip(" \n") + new_parts[0],
|
||||
f"\n{original_level * ' '}```",
|
||||
parts[1],
|
||||
"```",
|
||||
parts[2],
|
||||
]
|
||||
)
|
||||
docstring_opening, original_start_docstring = parts[0].rstrip(" \n").split('"""')[:2]
|
||||
new_start_docstring = new_parts[0].rstrip(" \n")
|
||||
docstring_opening += '"""'
|
||||
if new_start_docstring.startswith(original_start_docstring):
|
||||
updated_docstring = new_start_docstring + "\n" + updated_docstring
|
||||
elif original_start_docstring.endswith(new_start_docstring):
|
||||
updated_docstring = original_start_docstring + "\n" + updated_docstring
|
||||
else:
|
||||
updated_docstring = original_start_docstring + "\n" + new_start_docstring + "\n" + updated_docstring
|
||||
updated_docstring = docstring_opening + updated_docstring
|
||||
elif updated_docstring not in original_docstring:
|
||||
# add tabulation if we are at the lowest level.
|
||||
if re.search(r"\n\s*.*\(.*\)\:\n\s*\w", updated_docstring):
|
||||
|
||||
Reference in New Issue
Block a user