[Modular] skip modular checks based on diff (#36130)
skip modular checks based on diff
This commit is contained in:
@@ -2,6 +2,7 @@ import argparse
|
||||
import difflib
|
||||
import glob
|
||||
import logging
|
||||
import subprocess
|
||||
from io import StringIO
|
||||
|
||||
from create_dependency_mapping import find_priority_list
|
||||
@@ -61,6 +62,56 @@ def compare_files(modular_file_path, fix_and_overwrite=False):
|
||||
return diff
|
||||
|
||||
|
||||
def get_models_in_diff():
    """
    Finds all models that have been modified in the diff.

    The diff is taken between the merge base of `main` and `HEAD` and the current checkout, excluding deleted
    files (`--diff-filter=d`). Only Python files whose path contains `/models/` are considered, which covers both
    modeling files and model tests. The model name is taken to be the name of the directory that directly
    contains the file.

    Returns:
        A set containing the names of the models that have been modified (e.g. {'llama', 'whisper'}).

    Raises:
        subprocess.CalledProcessError: If one of the underlying git commands exits with a non-zero status.
    """
    # `git merge-base` prints the SHA followed by a newline: strip it so the SHA is a clean, single argument.
    fork_point_sha = subprocess.check_output(["git", "merge-base", "main", "HEAD"]).decode("utf-8").strip()

    # One path per output line: splitting on lines (not on arbitrary whitespace) keeps paths with spaces intact.
    modified_files = (
        subprocess.check_output(["git", "diff", "--diff-filter=d", "--name-only", fork_point_sha])
        .decode("utf-8")
        .splitlines()
    )

    # Matches both modelling files and tests
    relevant_modified_files = [x for x in modified_files if "/models/" in x and x.endswith(".py")]

    # The parent directory of each relevant file is used as the model name.
    return {file_path.split("/")[-2] for file_path in relevant_modified_files}
|
||||
|
||||
|
||||
def guaranteed_no_diff(modular_file_path, dependencies, models_in_diff):
    """
    Returns whether it is guaranteed to have no differences between the modular file and the modeling file.

    Model is in the diff -> not guaranteed to have no differences
    Dependency is in the diff -> not guaranteed to have no differences
    Dependency whose model cannot be identified -> not guaranteed to have no differences
    Otherwise -> guaranteed to have no differences

    Args:
        modular_file_path: The path to the modular file (its name must contain `modular_`).
        dependencies: A dictionary containing the dependencies of each modular file, keyed by modular file path.
        models_in_diff: A set containing the names of the models that have been modified.

    Returns:
        A boolean indicating whether the model (code and tests) is guaranteed to have no differences.
    """
    # `.../modular_llama.py` -> `llama`
    model_name = modular_file_path.rsplit("modular_", 1)[1].replace(".py", "")
    if model_name in models_in_diff:
        return False

    for dep in dependencies[modular_file_path]:
        # two possible patterns: `transformers.models.model_name.(...)` or `model_name.(...)`
        # in both cases the model name is the second-to-last dotted component
        dep_parts = dep.split(".")
        if len(dep_parts) < 2:
            # The model cannot be identified from this dependency: be conservative and do not skip the check.
            return False
        if dep_parts[-2] in models_in_diff:
            return False

    return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Compare modular_xxx.py files with modeling_xxx.py files.")
|
||||
parser.add_argument(
|
||||
@@ -72,9 +123,32 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
if args.files == ["all"]:
|
||||
args.files = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
|
||||
|
||||
# Assuming there is a topological sort on the dependency mapping: if the file being checked and its dependencies
|
||||
# are not in the diff, then it is guaranteed to have no differences. If no models are in the diff, then this
|
||||
# script will do nothing.
|
||||
models_in_diff = get_models_in_diff()
|
||||
if not models_in_diff:
|
||||
console.print("[bold green]No models files or model tests in the diff, skipping modular checks[/bold green]")
|
||||
exit(0)
|
||||
|
||||
skipped_models = set()
|
||||
non_matching_files = 0
|
||||
for modular_file_path in find_priority_list(args.files):
|
||||
ordered_files, dependencies = find_priority_list(args.files)
|
||||
for modular_file_path in ordered_files:
|
||||
is_guaranteed_no_diff = guaranteed_no_diff(modular_file_path, dependencies, models_in_diff)
|
||||
if is_guaranteed_no_diff:
|
||||
model_name = modular_file_path.rsplit("modular_", 1)[1].replace(".py", "")
|
||||
skipped_models.add(model_name)
|
||||
continue
|
||||
non_matching_files += compare_files(modular_file_path, args.fix_and_overwrite)
|
||||
models_in_diff = get_models_in_diff() # When overwriting, the diff changes
|
||||
|
||||
if non_matching_files and not args.fix_and_overwrite:
|
||||
raise ValueError("Some diff and their modeling code did not match.")
|
||||
|
||||
if skipped_models:
|
||||
console.print(
|
||||
f"[bold green]Skipped {len(skipped_models)} models and their dependencies that are not in the diff: "
|
||||
f"{', '.join(skipped_models)}[/bold green]"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user