From 48795317a21e9128d3ca877657acd855e9ba8477 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Fri, 5 Apr 2024 14:30:36 +0200 Subject: [PATCH] [test fetcher] Always include the directly related test files (#30050) * fix * fix --------- Co-authored-by: ydshieh --- utils/tests_fetcher.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/utils/tests_fetcher.py b/utils/tests_fetcher.py index 6cc22cc5f1..e54e6d0de4 100644 --- a/utils/tests_fetcher.py +++ b/utils/tests_fetcher.py @@ -958,10 +958,25 @@ def create_module_to_test_map( model_tests = {Path(t).parts[2] for t in tests if t.startswith("tests/models/")} return len(model_tests) > num_model_tests // 2 - def filter_tests(tests): - return [t for t in tests if not t.startswith("tests/models/") or Path(t).parts[2] in IMPORTANT_MODELS] + # for each module (if specified in the argument `module`) of the form `models/my_model` (i.e. starting with it), + # we always keep the tests (those are already in the argument `tests`) which are in `tests/models/my_model`. + # This is to avoid them being excluded when a module has many impacted tests: the directly related test files should + # always be included! + def filter_tests(tests, module=""): + return [ + t + for t in tests + if not t.startswith("tests/models/") + or Path(t).parts[2] in IMPORTANT_MODELS + # at this point, `t` is of the form `tests/models/my_model`, and we check if `models/my_model` + # (i.e. `parts[1:3]`) is in `module`. + or "/".join(Path(t).parts[1:3]) in module + ] - return {module: (filter_tests(tests) if has_many_models(tests) else tests) for module, tests in test_map.items()} + return { + module: (filter_tests(tests, module=module) if has_many_models(tests) else tests) + for module, tests in test_map.items() + } def check_imports_all_exist():