[test fetcher] Always include the directly related test files (#30050)
* fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -958,10 +958,25 @@ def create_module_to_test_map(
|
||||
model_tests = {Path(t).parts[2] for t in tests if t.startswith("tests/models/")}
|
||||
return len(model_tests) > num_model_tests // 2
|
||||
|
||||
def filter_tests(tests):
|
||||
return [t for t in tests if not t.startswith("tests/models/") or Path(t).parts[2] in IMPORTANT_MODELS]
|
||||
# for each module (if specified in the argument `module`) of the form `models/my_model` (i.e. starting with it),
|
||||
# we always keep the tests (those are already in the argument `tests`) which are in `tests/models/my_model`.
|
||||
# This is to avoid them being excluded when a module has many impacted tests: the directly related test files should
|
||||
# always be included!
|
||||
def filter_tests(tests, module=""):
|
||||
return [
|
||||
t
|
||||
for t in tests
|
||||
if not t.startswith("tests/models/")
|
||||
or Path(t).parts[2] in IMPORTANT_MODELS
|
||||
# at this point, `t` is of the form `tests/models/my_model`, and we check if `models/my_model`
|
||||
# (i.e. `parts[1:3]`) is in `module`.
|
||||
or "/".join(Path(t).parts[1:3]) in module
|
||||
]
|
||||
|
||||
return {module: (filter_tests(tests) if has_many_models(tests) else tests) for module, tests in test_map.items()}
|
||||
return {
|
||||
module: (filter_tests(tests, module=module) if has_many_models(tests) else tests)
|
||||
for module, tests in test_map.items()
|
||||
}
|
||||
|
||||
|
||||
def check_imports_all_exist():
|
||||
|
||||
Reference in New Issue
Block a user