From 87e2ea33aab6454be3afbd4f0342b518f15bccef Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Mon, 18 Mar 2024 14:32:42 +0100 Subject: [PATCH] Fix `filter_models` (#29710) * update * update * update * check --------- Co-authored-by: ydshieh --- utils/tests_fetcher.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/utils/tests_fetcher.py b/utils/tests_fetcher.py index 21e19fd7d1..af4785fb6d 100644 --- a/utils/tests_fetcher.py +++ b/utils/tests_fetcher.py @@ -68,6 +68,10 @@ PATH_TO_EXAMPLES = PATH_TO_REPO / "examples" PATH_TO_TRANFORMERS = PATH_TO_REPO / "src/transformers" PATH_TO_TESTS = PATH_TO_REPO / "tests" +# The value is just a heuristic to determine if we `guess` all models are impacted. +# This variable has effect only if `filter_models=False`. +NUM_MODELS_TO_TRIGGER_FULL_CI = 30 + # List here the models to always test. IMPORTANT_MODELS = [ "auto", @@ -1064,10 +1068,18 @@ def infer_tests_to_run( impacted_files = sorted(set(impacted_files)) print(f"\n### IMPACTED FILES ###\n{_print_list(impacted_files)}") + model_impacted = {"/".join(x.split("/")[:3]) for x in impacted_files if x.startswith("tests/models/")} + # Grab the corresponding test files: if any(x in modified_files for x in ["setup.py", ".circleci/create_circleci_config.py"]): test_files_to_run = ["tests", "examples"] repo_utils_launch = True + elif not filter_models and len(model_impacted) >= NUM_MODELS_TO_TRIGGER_FULL_CI: + print( + f"More than {NUM_MODELS_TO_TRIGGER_FULL_CI - 1} models are impacted and `filter_models=False`. CI is configured to test everything." + ) + test_files_to_run = ["tests", "examples"] + repo_utils_launch = True else: # All modified tests need to be run. test_files_to_run = [ @@ -1244,7 +1256,7 @@ if __name__ == "__main__": args.output_file, diff_with_last_commit=diff_with_last_commit, json_output_file=args.json_output_file, - filter_models=(not commit_flags["no_filter"] or is_main_branch), + filter_models=(not (commit_flags["no_filter"] or is_main_branch)), ) filter_tests(args.output_file, ["repo_utils"]) except Exception as e: