Copied from for test files (#26713)

* copied statement for test files

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2023-10-11 14:12:09 +02:00
committed by GitHub
parent 9f40639292
commit 5334796d20
14 changed files with 127 additions and 45 deletions

View File

@@ -51,6 +51,7 @@ from transformers.utils import direct_transformers_import
# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_copies.py
TRANSFORMERS_PATH = "src/transformers"
MODEL_TEST_PATH = "tests/models"
PATH_TO_DOCS = "docs/source/en"
REPO_PATH = "."
@@ -132,12 +133,15 @@ def _should_continue(line: str, indent: str) -> bool:
return line.startswith(indent) or len(line.strip()) == 0 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None
def find_code_in_transformers(object_name: str) -> str:
def find_code_in_transformers(object_name: str, base_path: str = None) -> str:
"""
Find and return the source code of an object.
Args:
object_name (`str`): The name of the object we want the source code of.
object_name (`str`):
The name of the object we want the source code of.
base_path (`str`, *optional*):
The path to the base folder where files are checked. If not set, it will be set to `TRANSFORMERS_PATH`.
Returns:
`str`: The source code of the object.
@@ -145,9 +149,21 @@ def find_code_in_transformers(object_name: str) -> str:
parts = object_name.split(".")
i = 0
# We can't set this as the default value in the argument, otherwise `CopyCheckTester` will fail, as it uses a
# patched temp directory.
if base_path is None:
base_path = TRANSFORMERS_PATH
# Detail: the `Copied from` statement is originally designed to work with the last part of `TRANSFORMERS_PATH`,
# (which is `transformers`). The same should be applied for `MODEL_TEST_PATH`. However, its last part is `models`
# (to only check and search in it) which is a bit confusing. So we keep the copied statement staring with
# `tests.models.` and change it to `tests` here.
if base_path == MODEL_TEST_PATH:
base_path = "tests"
# First let's find the module where our object lives.
module = parts[i]
while i < len(parts) and not os.path.isfile(os.path.join(TRANSFORMERS_PATH, f"{module}.py")):
while i < len(parts) and not os.path.isfile(os.path.join(base_path, f"{module}.py")):
i += 1
if i < len(parts):
module = os.path.join(module, parts[i])
@@ -156,7 +172,7 @@ def find_code_in_transformers(object_name: str) -> str:
f"`object_name` should begin with the name of a module of transformers but got {object_name}."
)
with open(os.path.join(TRANSFORMERS_PATH, f"{module}.py"), "r", encoding="utf-8", newline="\n") as f:
with open(os.path.join(base_path, f"{module}.py"), "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
# Now let's find the class / func in the code!
@@ -186,6 +202,7 @@ def find_code_in_transformers(object_name: str) -> str:
_re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+transformers\.(\S+\.\S+)\s*($|\S.*$)")
_re_copy_warning_for_test_file = re.compile(r"^(\s*)#\s*Copied from\s+tests\.(\S+\.\S+)\s*($|\S.*$)")
_re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)")
_re_fill_pattern = re.compile(r"<FILL\s+[^>]*>")
@@ -284,14 +301,20 @@ def is_copy_consistent(filename: str, overwrite: bool = False) -> Optional[List[
line_index = 0
# Not a for loop cause `lines` is going to change (if `overwrite=True`).
while line_index < len(lines):
search = _re_copy_warning.search(lines[line_index])
search_re = _re_copy_warning
if filename.startswith("tests"):
search_re = _re_copy_warning_for_test_file
search = search_re.search(lines[line_index])
if search is None:
line_index += 1
continue
# There is some copied code here, let's retrieve the original.
indent, object_name, replace_pattern = search.groups()
theoretical_code = find_code_in_transformers(object_name)
base_path = TRANSFORMERS_PATH if not filename.startswith("tests") else MODEL_TEST_PATH
theoretical_code = find_code_in_transformers(object_name, base_path=base_path)
theoretical_indent = get_indent(theoretical_code)
start_index = line_index + 1 if indent == theoretical_indent else line_index
@@ -357,6 +380,9 @@ def check_copies(overwrite: bool = False):
Whether or not to overwrite the copies when they don't match.
"""
all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True)
all_test_files = glob.glob(os.path.join(MODEL_TEST_PATH, "**/*.py"), recursive=True)
all_files = list(all_files) + list(all_test_files)
diffs = []
for filename in all_files:
new_diffs = is_copy_consistent(filename, overwrite)