From 858ce6879a4aa7fa76a7c4e2ac20388e087ace26 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Fri, 30 May 2025 11:19:42 +0200 Subject: [PATCH] make it go brrrr (#38409) * make it go brrrr * date time * update * fix * up * uppp * up * no number i * udpate * fix * [paligemma] fix processor with suffix (#38365) fix pg processor * [video utils] group and reorder by number of frames (#38374) fix * Fix convert to original state dict for VLMs (#38385) * fix convert to original state dict * fix * lint * Update modeling_utils.py * update * warn * no verbose * fginal * ouft * style --------- Co-authored-by: Raushan Turganbay Co-authored-by: hoshi-hiyouga --- utils/patch_helper.py | 150 +++++++++++++++++++++++++++++------------- 1 file changed, 105 insertions(+), 45 deletions(-) diff --git a/utils/patch_helper.py b/utils/patch_helper.py index f31e755efd..9b7b4b0b56 100644 --- a/utils/patch_helper.py +++ b/utils/patch_helper.py @@ -35,61 +35,121 @@ git cherry-pick 0bef4a273825d2cfc52ddfe62ba486ee61cc116f #2024-05-29 13:33:26+01 ``` """ -import argparse +import json +import subprocess -from git import GitCommandError, Repo -from packaging import version +import transformers -def get_merge_commit(repo, pr_number, since_tag): +LABEL = "for patch" # Replace with your label +REPO = "huggingface/transformers" # Optional if already in correct repo + + +def get_release_branch_name(): + """Derive branch name from transformers version.""" + major, minor, *_ = transformers.__version__.split(".") + major = int(major) + minor = int(minor) + + if minor == 0: + # Handle major version rollback, e.g., from 5.0 to 4.latest (if ever needed) + major -= 1 + # You'll need logic to determine the last minor of the previous major version + raise ValueError("Minor version is 0; need logic to find previous major version's last minor") + else: + minor -= 1 + + return f"v{major}.{minor}-release" + + +def checkout_branch(branch): + """Checkout the target branch.""" try: - # Use git log to find the merge commit for the PR within the given tag range - merge_commit = next(repo.iter_commits(f"v{since_tag}...origin/main", grep=f"#{pr_number}")) - return merge_commit - except StopIteration: - print(f"No merge commit found for PR #{pr_number} between tags {since_tag} and {main}") - return None - except GitCommandError as e: - print(f"Error finding merge commit for PR #{pr_number}: {str(e)}") - return None + subprocess.run(["git", "checkout", branch], check=True) + print(f"✅ Checked out branch: {branch}") + except subprocess.CalledProcessError: + print(f"❌ Failed to checkout branch: {branch}. Does it exist?") + exit(1) -def main(pr_numbers): - repo = Repo(".") # Initialize the Repo object for the current directory - merge_commits = [] +def get_prs_by_label(label): + """Call gh CLI to get PRs with a specific label.""" + cmd = [ + "gh", + "pr", + "list", + "--label", + label, + "--state", + "all", + "--json", + "number,title,mergeCommit,url", + "--limit", + "100", + ] + result = subprocess.run(cmd, capture_output=True, text=True) + result.check_returncode() + prs = json.loads(result.stdout) + for pr in prs: + is_merged = pr.get("mergeCommit", {}) + if is_merged: + pr["oid"] = is_merged.get("oid") + return prs - tags = {} - for tag in repo.tags: - try: - # Parse and sort tags, skip invalid ones - tag_ver = version.parse(tag.name) - tags[tag_ver] = tag - except Exception: - print(f"Skipping invalid version tag: {tag.name}") - last_tag = sorted(tags)[-1] - major_minor = f"{last_tag.major}.{last_tag.minor}.0" - # Iterate through tag ranges to find the merge commits - for pr in pr_numbers: - pr = pr.split("https://github.com/huggingface/transformers/pull/")[-1] - commit = get_merge_commit(repo, pr, major_minor) - if commit: - merge_commits.append(commit) +def get_commit_timestamp(commit_sha): + """Get UNIX timestamp of a commit using git.""" + result = subprocess.run(["git", "show", "-s", "--format=%ct", commit_sha], capture_output=True, text=True) + result.check_returncode() + return int(result.stdout.strip()) - # Sort commits by date - merge_commits.sort(key=lambda commit: commit.committed_datetime) - # Output the git cherry-pick commands - print("Git cherry-pick commands to run:") - for commit in merge_commits: - print(f"git cherry-pick {commit.hexsha} #{commit.committed_datetime}") +def cherry_pick_commit(sha): + """Cherry-pick a given commit SHA.""" + try: + subprocess.run(["git", "cherry-pick", sha], check=True) + print(f"✅ Cherry-picked commit {sha}") + except subprocess.CalledProcessError: + print(f"⚠️ Failed to cherry-pick {sha}. Manual intervention required.") + + +def commit_in_history(commit_sha, base_branch="HEAD"): + """Return True if commit is already part of base_branch history.""" + result = subprocess.run( + ["git", "merge-base", "--is-ancestor", commit_sha, base_branch], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + return result.returncode == 0 + + +def main(verbose=False): + branch = get_release_branch_name() + checkout_branch(branch) + prs = get_prs_by_label(LABEL) + # Attach commit timestamps + for pr in prs: + sha = pr.get("oid") + if sha: + pr["timestamp"] = get_commit_timestamp(sha) + else: + print("\n" + "=" * 80) + print(f"⚠️ WARNING: PR #{pr['number']} ({sha}) is NOT in main!") + print("⚠️ A core maintainer must review this before cherry-picking.") + print("=" * 80 + "\n") + # Sort by commit timestamp (ascending) + prs = [pr for pr in prs if pr.get("timestamp") is not None] + prs.sort(key=lambda pr: pr["timestamp"]) + for pr in prs: + sha = pr.get("oid") + if sha: + if commit_in_history(sha): + if verbose: + print(f"🔁 PR #{pr['number']} ({pr['title']}) already in history. Skipping.") + else: + print(f"🚀 PR #{pr['number']} ({pr['title']}) not in history. Cherry-picking...") + cherry_pick_commit(sha) if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Find and sort merge commits for specified PRs.") - parser.add_argument("--prs", nargs="+", required=False, type=str, help="PR numbers to find merge commits for") - - args = parser.parse_args() - if args.prs is None: - args.prs = "https://github.com/huggingface/transformers/pull/33753 https://github.com/huggingface/transformers/pull/33861 https://github.com/huggingface/transformers/pull/33906 https://github.com/huggingface/transformers/pull/33761 https://github.com/huggingface/transformers/pull/33586 https://github.com/huggingface/transformers/pull/33766 https://github.com/huggingface/transformers/pull/33958 https://github.com/huggingface/transformers/pull/33965".split() - main(args.prs) + main()