`Copied from` for test files (#26713)
* copied statement for test files --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -51,6 +51,7 @@ from transformers.utils import direct_transformers_import
|
||||
# All paths are set with the intent you should run this script from the root of the repo with the command
|
||||
# python utils/check_copies.py
|
||||
TRANSFORMERS_PATH = "src/transformers"
|
||||
MODEL_TEST_PATH = "tests/models"
|
||||
PATH_TO_DOCS = "docs/source/en"
|
||||
REPO_PATH = "."
|
||||
|
||||
@@ -132,12 +133,15 @@ def _should_continue(line: str, indent: str) -> bool:
|
||||
return line.startswith(indent) or len(line.strip()) == 0 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None
|
||||
|
||||
|
||||
def find_code_in_transformers(object_name: str) -> str:
|
||||
def find_code_in_transformers(object_name: str, base_path: str = None) -> str:
|
||||
"""
|
||||
Find and return the source code of an object.
|
||||
|
||||
Args:
|
||||
object_name (`str`): The name of the object we want the source code of.
|
||||
object_name (`str`):
|
||||
The name of the object we want the source code of.
|
||||
base_path (`str`, *optional*):
|
||||
The path to the base folder where files are checked. If not set, it will be set to `TRANSFORMERS_PATH`.
|
||||
|
||||
Returns:
|
||||
`str`: The source code of the object.
|
||||
@@ -145,9 +149,21 @@ def find_code_in_transformers(object_name: str) -> str:
|
||||
parts = object_name.split(".")
|
||||
i = 0
|
||||
|
||||
# We can't set this as the default value in the argument, otherwise `CopyCheckTester` will fail, as it uses a
|
||||
# patched temp directory.
|
||||
if base_path is None:
|
||||
base_path = TRANSFORMERS_PATH
|
||||
|
||||
# Detail: the `Copied from` statement is originally designed to work with the last part of `TRANSFORMERS_PATH`,
|
||||
# (which is `transformers`). The same should be applied for `MODEL_TEST_PATH`. However, its last part is `models`
|
||||
# (to only check and search in it) which is a bit confusing. So we keep the copied statement starting with
|
||||
# `tests.models.` and change it to `tests` here.
|
||||
if base_path == MODEL_TEST_PATH:
|
||||
base_path = "tests"
|
||||
|
||||
# First let's find the module where our object lives.
|
||||
module = parts[i]
|
||||
while i < len(parts) and not os.path.isfile(os.path.join(TRANSFORMERS_PATH, f"{module}.py")):
|
||||
while i < len(parts) and not os.path.isfile(os.path.join(base_path, f"{module}.py")):
|
||||
i += 1
|
||||
if i < len(parts):
|
||||
module = os.path.join(module, parts[i])
|
||||
@@ -156,7 +172,7 @@ def find_code_in_transformers(object_name: str) -> str:
|
||||
f"`object_name` should begin with the name of a module of transformers but got {object_name}."
|
||||
)
|
||||
|
||||
with open(os.path.join(TRANSFORMERS_PATH, f"{module}.py"), "r", encoding="utf-8", newline="\n") as f:
|
||||
with open(os.path.join(base_path, f"{module}.py"), "r", encoding="utf-8", newline="\n") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
# Now let's find the class / func in the code!
|
||||
@@ -186,6 +202,7 @@ def find_code_in_transformers(object_name: str) -> str:
|
||||
|
||||
|
||||
_re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+transformers\.(\S+\.\S+)\s*($|\S.*$)")
|
||||
_re_copy_warning_for_test_file = re.compile(r"^(\s*)#\s*Copied from\s+tests\.(\S+\.\S+)\s*($|\S.*$)")
|
||||
_re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)")
|
||||
_re_fill_pattern = re.compile(r"<FILL\s+[^>]*>")
|
||||
|
||||
@@ -284,14 +301,20 @@ def is_copy_consistent(filename: str, overwrite: bool = False) -> Optional[List[
|
||||
line_index = 0
|
||||
# Not a `for` loop because `lines` is going to change (if `overwrite=True`).
|
||||
while line_index < len(lines):
|
||||
search = _re_copy_warning.search(lines[line_index])
|
||||
search_re = _re_copy_warning
|
||||
if filename.startswith("tests"):
|
||||
search_re = _re_copy_warning_for_test_file
|
||||
|
||||
search = search_re.search(lines[line_index])
|
||||
if search is None:
|
||||
line_index += 1
|
||||
continue
|
||||
|
||||
# There is some copied code here, let's retrieve the original.
|
||||
indent, object_name, replace_pattern = search.groups()
|
||||
theoretical_code = find_code_in_transformers(object_name)
|
||||
|
||||
base_path = TRANSFORMERS_PATH if not filename.startswith("tests") else MODEL_TEST_PATH
|
||||
theoretical_code = find_code_in_transformers(object_name, base_path=base_path)
|
||||
theoretical_indent = get_indent(theoretical_code)
|
||||
|
||||
start_index = line_index + 1 if indent == theoretical_indent else line_index
|
||||
@@ -357,6 +380,9 @@ def check_copies(overwrite: bool = False):
|
||||
Whether or not to overwrite the copies when they don't match.
|
||||
"""
|
||||
all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True)
|
||||
all_test_files = glob.glob(os.path.join(MODEL_TEST_PATH, "**/*.py"), recursive=True)
|
||||
all_files = list(all_files) + list(all_test_files)
|
||||
|
||||
diffs = []
|
||||
for filename in all_files:
|
||||
new_diffs = is_copy_consistent(filename, overwrite)
|
||||
|
||||
Reference in New Issue
Block a user