[Modular] skip modular checks based on diff (#36130)
skip modular checks based on diff
This commit is contained in:
@@ -48,7 +48,7 @@ def appear_after(model1: str, model2: str, priority_list: list[str]) -> bool:
|
|||||||
class ConversionOrderTest(unittest.TestCase):
|
class ConversionOrderTest(unittest.TestCase):
|
||||||
def test_conversion_order(self):
|
def test_conversion_order(self):
|
||||||
# Find the order
|
# Find the order
|
||||||
priority_list = create_dependency_mapping.find_priority_list(FILES_TO_PARSE)
|
priority_list, _ = create_dependency_mapping.find_priority_list(FILES_TO_PARSE)
|
||||||
# Extract just the model names
|
# Extract just the model names
|
||||||
model_priority_list = [file.rsplit("modular_")[-1].replace(".py", "") for file in priority_list]
|
model_priority_list = [file.rsplit("modular_")[-1].replace(".py", "") for file in priority_list]
|
||||||
|
|
||||||
|
|||||||
@@ -1024,40 +1024,6 @@ def convert_to_localized_md(model_list: str, localized_model_list: str, format_s
|
|||||||
return readmes_match, "\n".join((x[1] for x in sorted_index)) + "\n"
|
return readmes_match, "\n".join((x[1] for x in sorted_index)) + "\n"
|
||||||
|
|
||||||
|
|
||||||
def _find_text_in_file(filename: str, start_prompt: str, end_prompt: str) -> Tuple[str, int, int, List[str]]:
|
|
||||||
"""
|
|
||||||
Find the text in a file between two prompts.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
filename (`str`): The name of the file to look into.
|
|
||||||
start_prompt (`str`): The string to look for that introduces the content looked for.
|
|
||||||
end_prompt (`str`): The string to look for that ends the content looked for.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[str, int, int, List[str]]: The content between the two prompts, the index of the start line in the
|
|
||||||
original file, the index of the end line in the original file and the list of lines of that file.
|
|
||||||
"""
|
|
||||||
with open(filename, "r", encoding="utf-8", newline="\n") as f:
|
|
||||||
lines = f.readlines()
|
|
||||||
# Find the start prompt.
|
|
||||||
start_index = 0
|
|
||||||
while not lines[start_index].startswith(start_prompt):
|
|
||||||
start_index += 1
|
|
||||||
start_index += 1
|
|
||||||
|
|
||||||
end_index = start_index
|
|
||||||
while not lines[end_index].startswith(end_prompt):
|
|
||||||
end_index += 1
|
|
||||||
end_index -= 1
|
|
||||||
|
|
||||||
while len(lines[start_index]) <= 1:
|
|
||||||
start_index += 1
|
|
||||||
while len(lines[end_index]) <= 1:
|
|
||||||
end_index -= 1
|
|
||||||
end_index += 1
|
|
||||||
return "".join(lines[start_index:end_index]), start_index, end_index, lines
|
|
||||||
|
|
||||||
|
|
||||||
# Map a model name with the name it has in the README for the check_readme check
|
# Map a model name with the name it has in the README for the check_readme check
|
||||||
SPECIAL_MODEL_NAMES = {
|
SPECIAL_MODEL_NAMES = {
|
||||||
"Bert Generation": "BERT For Sequence Generation",
|
"Bert Generation": "BERT For Sequence Generation",
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import argparse
|
|||||||
import difflib
|
import difflib
|
||||||
import glob
|
import glob
|
||||||
import logging
|
import logging
|
||||||
|
import subprocess
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
|
|
||||||
from create_dependency_mapping import find_priority_list
|
from create_dependency_mapping import find_priority_list
|
||||||
@@ -61,6 +62,56 @@ def compare_files(modular_file_path, fix_and_overwrite=False):
|
|||||||
return diff
|
return diff
|
||||||
|
|
||||||
|
|
||||||
|
def get_models_in_diff():
    """
    Finds all models that have been modified in the diff.

    The diff is taken between the working tree and the merge base of `main` and `HEAD`. Deleted files are ignored
    (`--diff-filter=d`). Only `.py` files whose path contains `/models/` are considered, which covers both the
    modeling files and the model tests.

    Returns:
        A set containing the names of the models that have been modified (e.g. {'llama', 'whisper'}).

    Raises:
        subprocess.CalledProcessError: If one of the `git` commands exits with a non-zero status.
    """
    # Strip the trailing newline so that the SHA can be handed to git as a single clean argument
    fork_point_sha = subprocess.check_output(["git", "merge-base", "main", "HEAD"]).decode("utf-8").strip()

    # `-z` makes git print NUL-terminated, unquoted paths, so a path containing whitespace is not split apart
    diff_output = subprocess.check_output(
        ["git", "diff", "--diff-filter=d", "--name-only", "-z", fork_point_sha]
    ).decode("utf-8")
    modified_files = [path for path in diff_output.split("\0") if path]

    # Matches both modelling files and tests. The model name is the directory that directly contains the file.
    return {
        file_path.split("/")[-2]
        for file_path in modified_files
        if "/models/" in file_path and file_path.endswith(".py")
    }
|
||||||
|
|
||||||
|
|
||||||
|
def guaranteed_no_diff(modular_file_path, dependencies, models_in_diff):
    """
    Returns whether it is guaranteed to have no differences between the modular file and the modeling file.

    Model is in the diff -> not guaranteed to have no differences
    Dependency is in the diff -> not guaranteed to have no differences
    Dependency cannot be mapped to a model -> not guaranteed to have no differences
    Otherwise -> guaranteed to have no differences

    Args:
        modular_file_path: The path to the modular file. Its file name must contain `modular_`.
        dependencies: A dictionary containing the dependencies of each modular file, keyed by modular file path.
        models_in_diff: A set containing the names of the models that have been modified.

    Returns:
        A boolean indicating whether the model (code and tests) is guaranteed to have no differences.

    Raises:
        IndexError: If `modular_file_path` does not contain `modular_`.
        KeyError: If `modular_file_path` has no entry in `dependencies`.
    """
    file_name = modular_file_path.rsplit("modular_", 1)[1]
    # Only drop a trailing `.py`, never a `.py` that happens to appear inside the name
    model_name = file_name[: -len(".py")] if file_name.endswith(".py") else file_name
    if model_name in models_in_diff:
        return False

    for dep in dependencies[modular_file_path]:
        # two possible patterns: `transformers.models.model_name.(...)` or `model_name.(...)`
        dep_parts = dep.split(".")
        if len(dep_parts) < 2:
            # No package component, so the owning model is unknown: be conservative and do not skip the check
            return False
        if dep_parts[-2] in models_in_diff:
            return False
    return True
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Compare modular_xxx.py files with modeling_xxx.py files.")
|
parser = argparse.ArgumentParser(description="Compare modular_xxx.py files with modeling_xxx.py files.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -72,9 +123,32 @@ if __name__ == "__main__":
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if args.files == ["all"]:
|
if args.files == ["all"]:
|
||||||
args.files = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
|
args.files = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
|
||||||
|
|
||||||
|
# Assuming there is a topological sort on the dependency mapping: if the file being checked and its dependencies
|
||||||
|
# are not in the diff, then it is guaranteed to have no differences. If no models are in the diff, then this
|
||||||
|
# script will do nothing.
|
||||||
|
models_in_diff = get_models_in_diff()
|
||||||
|
if not models_in_diff:
|
||||||
|
console.print("[bold green]No models files or model tests in the diff, skipping modular checks[/bold green]")
|
||||||
|
exit(0)
|
||||||
|
|
||||||
|
skipped_models = set()
|
||||||
non_matching_files = 0
|
non_matching_files = 0
|
||||||
for modular_file_path in find_priority_list(args.files):
|
ordered_files, dependencies = find_priority_list(args.files)
|
||||||
|
for modular_file_path in ordered_files:
|
||||||
|
is_guaranteed_no_diff = guaranteed_no_diff(modular_file_path, dependencies, models_in_diff)
|
||||||
|
if is_guaranteed_no_diff:
|
||||||
|
model_name = modular_file_path.rsplit("modular_", 1)[1].replace(".py", "")
|
||||||
|
skipped_models.add(model_name)
|
||||||
|
continue
|
||||||
non_matching_files += compare_files(modular_file_path, args.fix_and_overwrite)
|
non_matching_files += compare_files(modular_file_path, args.fix_and_overwrite)
|
||||||
|
models_in_diff = get_models_in_diff() # When overwriting, the diff changes
|
||||||
|
|
||||||
if non_matching_files and not args.fix_and_overwrite:
|
if non_matching_files and not args.fix_and_overwrite:
|
||||||
raise ValueError("Some diff and their modeling code did not match.")
|
raise ValueError("Some diff and their modeling code did not match.")
|
||||||
|
|
||||||
|
if skipped_models:
|
||||||
|
console.print(
|
||||||
|
f"[bold green]Skipped {len(skipped_models)} models and their dependencies that are not in the diff: "
|
||||||
|
f"{', '.join(skipped_models)}[/bold green]"
|
||||||
|
)
|
||||||
|
|||||||
@@ -55,6 +55,16 @@ def map_dependencies(py_files):
|
|||||||
|
|
||||||
|
|
||||||
def find_priority_list(py_files):
|
def find_priority_list(py_files):
|
||||||
|
"""
|
||||||
|
Given a list of modular files, sorts them by topological order. Modular models that DON'T depend on other modular
|
||||||
|
models will be higher in the topological order.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
py_files: List of paths to the modular files
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple with the ordered files (list) and their dependencies (dict)
|
||||||
|
"""
|
||||||
dependencies = map_dependencies(py_files)
|
dependencies = map_dependencies(py_files)
|
||||||
ordered_classes = topological_sort(dependencies)
|
ordered_files = topological_sort(dependencies)
|
||||||
return ordered_classes
|
return ordered_files, dependencies
|
||||||
|
|||||||
@@ -1716,7 +1716,7 @@ if __name__ == "__main__":
|
|||||||
if args.files_to_parse == ["examples"]:
|
if args.files_to_parse == ["examples"]:
|
||||||
args.files_to_parse = glob.glob("examples/**/modular_*.py", recursive=True)
|
args.files_to_parse = glob.glob("examples/**/modular_*.py", recursive=True)
|
||||||
|
|
||||||
priority_list = find_priority_list(args.files_to_parse)
|
priority_list, _ = find_priority_list(args.files_to_parse)
|
||||||
assert len(priority_list) == len(args.files_to_parse), "Some files will not be converted"
|
assert len(priority_list) == len(args.files_to_parse), "Some files will not be converted"
|
||||||
|
|
||||||
for file_name in priority_list:
|
for file_name in priority_list:
|
||||||
|
|||||||
Reference in New Issue
Block a user