[modular] speedup check_modular_conversion with multiprocessing (#37456)

* Change topological sort to return level-based output (lists of lists)

* Update main for modular converter

* Update test

* update check_modular_conversion

* Update gitignore

* Fix missing conversion for glm4

* Update

* Fix error msg

* Fixup

* fix docstring

* update docs

* Add comment

* delete qwen3_moe
This commit is contained in:
Pavel Iakubovskii
2025-07-10 19:07:59 +01:00
committed by GitHub
parent 571a8c2131
commit fe1a5b73e6
5 changed files with 125 additions and 45 deletions

View File

@@ -2,7 +2,11 @@ import argparse
import difflib
import glob
import logging
import multiprocessing
import os
import shutil
import subprocess
from functools import partial
from io import StringIO
from create_dependency_mapping import find_priority_list
@@ -17,8 +21,15 @@ logging.basicConfig()
logging.getLogger().setLevel(logging.ERROR)
console = Console()
BACKUP_EXT = ".modular_backup"
def process_file(modular_file_path, generated_modeling_content, file_type="modeling_", fix_and_overwrite=False):
def process_file(
modular_file_path,
generated_modeling_content,
file_type="modeling_",
show_diff=True,
):
file_name_prefix = file_type.split("*")[0]
file_name_suffix = file_type.split("*")[-1] if "*" in file_type else ""
file_path = modular_file_path.replace("modular_", f"{file_name_prefix}_").replace(".py", f"{file_name_suffix}.py")
@@ -38,11 +49,14 @@ def process_file(modular_file_path, generated_modeling_content, file_type="model
diff_list = list(diff)
# Check for differences
if diff_list:
if fix_and_overwrite:
with open(file_path, "w", encoding="utf-8", newline="\n") as modeling_file:
modeling_file.write(generated_modeling_content[file_type][0])
console.print(f"[bold blue]Overwritten {file_path} with the generated content.[/bold blue]")
else:
# first save the copy of the original file, to be able to restore it later
if os.path.exists(file_path):
shutil.copy(file_path, file_path + BACKUP_EXT)
# we always save the generated content, to be able to update dependant files
with open(file_path, "w", encoding="utf-8", newline="\n") as modeling_file:
modeling_file.write(generated_modeling_content[file_type][0])
console.print(f"[bold blue]Overwritten {file_path} with the generated content.[/bold blue]")
if show_diff:
console.print(f"\n[bold red]Differences found between the generated code and {file_path}:[/bold red]\n")
diff_text = "\n".join(diff_list)
syntax = Syntax(diff_text, "diff", theme="ansi_dark", line_numbers=True)
@@ -53,12 +67,12 @@ def process_file(modular_file_path, generated_modeling_content, file_type="model
return 0
def compare_files(modular_file_path, fix_and_overwrite=False):
def compare_files(modular_file_path, show_diff=True):
# Generate the expected modeling content
generated_modeling_content = convert_modular_file(modular_file_path)
diff = 0
for file_type in generated_modeling_content.keys():
diff += process_file(modular_file_path, generated_modeling_content, file_type, fix_and_overwrite)
diff += process_file(modular_file_path, generated_modeling_content, file_type, show_diff)
return diff
@@ -120,16 +134,20 @@ if __name__ == "__main__":
parser.add_argument(
"--fix_and_overwrite", action="store_true", help="Overwrite the modeling_xxx.py file if differences are found."
)
parser.add_argument("--check_all", action="store_true", help="Check all files, not just the ones in the diff.")
parser.add_argument(
"--num_workers",
default=1,
default=-1,
type=int,
help="The number of workers to run. No effect if `fix_and_overwrite` is specified.",
help="The number of workers to run. Default is -1, which means the number of CPU cores.",
)
args = parser.parse_args()
if args.files == ["all"]:
args.files = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
if args.num_workers == -1:
args.num_workers = multiprocessing.cpu_count()
# Assuming there is a topological sort on the dependency mapping: if the file being checked and its dependencies
# are not in the diff, then there it is guaranteed to have no differences. If no models are in the diff, then this
# script will do nothing.
@@ -141,44 +159,71 @@ if __name__ == "__main__":
models_in_diff = {file_path.split("/")[-2] for file_path in args.files}
else:
models_in_diff = get_models_in_diff()
if not models_in_diff:
if not models_in_diff and not args.check_all:
console.print(
"[bold green]No models files or model tests in the diff, skipping modular checks[/bold green]"
)
exit(0)
skipped_models = set()
non_matching_files = 0
non_matching_files = []
ordered_files, dependencies = find_priority_list(args.files)
if args.fix_and_overwrite or args.num_workers == 1:
for modular_file_path in ordered_files:
is_guaranteed_no_diff = guaranteed_no_diff(modular_file_path, dependencies, models_in_diff)
if is_guaranteed_no_diff:
model_name = modular_file_path.rsplit("modular_", 1)[1].replace(".py", "")
skipped_models.add(model_name)
flat_ordered_files = [item for sublist in ordered_files for item in sublist]
# ordered_files is a *sorted* list of lists of filepaths
# - files from the first list do NOT depend on other files
# - files in the second list depend on files from the first list
# - files in the third list depend on files from the second and (optionally) the first list
# - ... and so on
# files (models) within the same list are *independent* of each other;
# we start applying modular conversion to each list in parallel, starting from the first list
console.print(f"[bold yellow]Number of dependency levels: {len(ordered_files)}[/bold yellow]")
console.print(f"[bold yellow]Files per level: {tuple([len(x) for x in ordered_files])}[/bold yellow]")
try:
for dependency_level_files in ordered_files:
# Filter files guaranteed no diff
files_to_check = []
for file_path in dependency_level_files:
if not args.check_all and guaranteed_no_diff(file_path, dependencies, models_in_diff):
skipped_models.add(file_path.split("/")[-2]) # save model folder name
else:
files_to_check.append(file_path)
if not files_to_check:
continue
non_matching_files += compare_files(modular_file_path, args.fix_and_overwrite)
if current_branch != "main":
models_in_diff = get_models_in_diff() # When overwriting, the diff changes
else:
new_ordered_files = []
for modular_file_path in ordered_files:
is_guaranteed_no_diff = guaranteed_no_diff(modular_file_path, dependencies, models_in_diff)
if is_guaranteed_no_diff:
model_name = modular_file_path.rsplit("modular_", 1)[1].replace(".py", "")
skipped_models.add(model_name)
else:
new_ordered_files.append(modular_file_path)
import multiprocessing
# Process files with diff
num_workers = min(args.num_workers, len(files_to_check))
with multiprocessing.Pool(num_workers) as p:
is_changed_flags = p.map(
partial(compare_files, show_diff=not args.fix_and_overwrite),
files_to_check,
)
with multiprocessing.Pool(args.num_workers) as p:
outputs = p.map(compare_files, new_ordered_files)
for output in outputs:
non_matching_files += output
# Collect changed files and their original paths
for is_changed, file_path in zip(is_changed_flags, files_to_check):
if is_changed:
non_matching_files.append(file_path)
# Update changed models, after each round of conversions
# (save model folder name)
models_in_diff.add(file_path.split("/")[-2])
finally:
# Restore overwritten files by modular (if needed)
backup_files = glob.glob("**/*" + BACKUP_EXT, recursive=True)
for backup_file_path in backup_files:
overwritten_path = backup_file_path.replace(BACKUP_EXT, "")
if not args.fix_and_overwrite and os.path.exists(overwritten_path):
shutil.copy(backup_file_path, overwritten_path)
os.remove(backup_file_path)
if non_matching_files and not args.fix_and_overwrite:
raise ValueError("Some diff and their modeling code did not match.")
diff_models = set(file_path.split("/")[-2] for file_path in non_matching_files) # noqa
models_str = "\n - " + "\n - ".join(sorted(diff_models))
raise ValueError(f"Some diff and their modeling code did not match. Models in diff:{models_str}")
if skipped_models:
console.print(