Revamp test selection for the example tests (#23737)
* Revamp test selection for the example tests * Rename old XLA test and fake modif in run_glue * Fixes * Fake Trainer modif * Remove fake modifs
This commit is contained in:
@@ -46,6 +46,7 @@ from git import Repo
|
||||
|
||||
|
||||
PATH_TO_REPO = Path(__file__).parent.parent.resolve()
|
||||
PATH_TO_EXAMPLES = PATH_TO_REPO / "examples"
|
||||
PATH_TO_TRANFORMERS = PATH_TO_REPO / "src/transformers"
|
||||
PATH_TO_TESTS = PATH_TO_REPO / "tests"
|
||||
|
||||
@@ -512,15 +513,40 @@ def print_tree_deps_of(module, all_edges=None):
|
||||
print(line[0])
|
||||
|
||||
|
||||
def init_test_examples_dependencies():
|
||||
"""
|
||||
The test examples do not import from the examples (which are just scripts, not modules) so we need som extra
|
||||
care initializing the dependency map there.
|
||||
"""
|
||||
test_example_deps = {}
|
||||
all_examples = []
|
||||
for framework in ["flax", "pytorch", "tensorflow"]:
|
||||
test_files = list((PATH_TO_EXAMPLES / framework).glob("test_*.py"))
|
||||
all_examples.extend(test_files)
|
||||
examples = [
|
||||
f for f in (PATH_TO_EXAMPLES / framework).glob("**/*.py") if f.parent != PATH_TO_EXAMPLES / framework
|
||||
]
|
||||
all_examples.extend(examples)
|
||||
for test_file in test_files:
|
||||
with open(test_file, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
test_example_deps[str(test_file.relative_to(PATH_TO_REPO))] = [
|
||||
str(e.relative_to(PATH_TO_REPO)) for e in examples if e.name in content
|
||||
]
|
||||
return test_example_deps, all_examples
|
||||
|
||||
|
||||
def create_reverse_dependency_map():
|
||||
"""
|
||||
Create the dependency map from module/test filename to the list of modules/tests that depend on it (even
|
||||
recursively).
|
||||
"""
|
||||
cache = {}
|
||||
all_modules = list(PATH_TO_TRANFORMERS.glob("**/*.py")) + list(PATH_TO_TESTS.glob("**/*.py"))
|
||||
example_deps, examples = init_test_examples_dependencies()
|
||||
all_modules = list(PATH_TO_TRANFORMERS.glob("**/*.py")) + list(PATH_TO_TESTS.glob("**/*.py")) + examples
|
||||
all_modules = [str(mod.relative_to(PATH_TO_REPO)) for mod in all_modules]
|
||||
direct_deps = {m: get_module_dependencies(m, cache=cache) for m in all_modules}
|
||||
direct_deps.update(example_deps)
|
||||
|
||||
# This recurses the dependencies
|
||||
something_changed = True
|
||||
@@ -557,7 +583,15 @@ def create_module_to_test_map(reverse_map=None, filter_models=False):
|
||||
"""
|
||||
if reverse_map is None:
|
||||
reverse_map = create_reverse_dependency_map()
|
||||
test_map = {module: [f for f in deps if f.startswith("tests")] for module, deps in reverse_map.items()}
|
||||
|
||||
def is_test(fname):
|
||||
if fname.startswith("tests"):
|
||||
return True
|
||||
if fname.startswith("examples") and fname.split(os.path.sep)[-1].startswith("test"):
|
||||
return True
|
||||
return False
|
||||
|
||||
test_map = {module: [f for f in deps if is_test(f)] for module, deps in reverse_map.items()}
|
||||
|
||||
if not filter_models:
|
||||
return test_map
|
||||
@@ -627,9 +661,7 @@ def create_json_map(test_files_to_run, json_output_file):
|
||||
json.dump(test_map, fp, ensure_ascii=False)
|
||||
|
||||
|
||||
def infer_tests_to_run(
|
||||
output_file, diff_with_last_commit=False, filters=None, filter_models=True, json_output_file=None
|
||||
):
|
||||
def infer_tests_to_run(output_file, diff_with_last_commit=False, filter_models=True, json_output_file=None):
|
||||
modified_files = get_modified_python_files(diff_with_last_commit=diff_with_last_commit)
|
||||
print(f"\n### MODIFIED FILES ###\n{_print_list(modified_files)}")
|
||||
|
||||
@@ -663,11 +695,6 @@ def infer_tests_to_run(
|
||||
test_files_to_run = [f for f in test_files_to_run if not f.split(os.path.sep)[1] == "sagemaker"]
|
||||
# Make sure we did not end up with a test file that was removed
|
||||
test_files_to_run = [f for f in test_files_to_run if (PATH_TO_REPO / f).exists()]
|
||||
if filters is not None:
|
||||
filtered_files = []
|
||||
for _filter in filters:
|
||||
filtered_files.extend([f for f in test_files_to_run if f.startswith(_filter)])
|
||||
test_files_to_run = filtered_files
|
||||
|
||||
repo_utils_launch = any(f.split(os.path.sep)[1] == "repo_utils" for f in modified_files)
|
||||
|
||||
@@ -676,6 +703,8 @@ def infer_tests_to_run(
|
||||
with open(repo_util_file, "w", encoding="utf-8") as f:
|
||||
f.write("tests/repo_utils")
|
||||
|
||||
examples_tests_to_run = [f for f in test_files_to_run if f.startswith("examples")]
|
||||
test_files_to_run = [f for f in test_files_to_run if not f.startswith("examples")]
|
||||
print(f"\n### TEST TO RUN ###\n{_print_list(test_files_to_run)}")
|
||||
if len(test_files_to_run) > 0:
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
@@ -690,6 +719,12 @@ def infer_tests_to_run(
|
||||
|
||||
create_json_map(test_files_to_run, json_output_file)
|
||||
|
||||
print(f"\n### EXAMPLES TEST TO RUN ###\n{_print_list(examples_tests_to_run)}")
|
||||
if len(examples_tests_to_run) > 0:
|
||||
example_file = Path(output_file).parent / "examples_test_list.txt"
|
||||
with open(example_file, "w", encoding="utf-8") as f:
|
||||
f.write(" ".join(examples_tests_to_run))
|
||||
|
||||
doctest_list = get_doctest_files()
|
||||
|
||||
print(f"\n### DOCTEST TO RUN ###\n{_print_list(doctest_list)}")
|
||||
@@ -763,13 +798,6 @@ if __name__ == "__main__":
|
||||
action="store_true",
|
||||
help="To fetch the tests between the current commit and the last commit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--filters",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=["tests"],
|
||||
help="Only keep the test files matching one of those filters.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--filter_tests",
|
||||
action="store_true",
|
||||
@@ -814,7 +842,6 @@ if __name__ == "__main__":
|
||||
infer_tests_to_run(
|
||||
args.output_file,
|
||||
diff_with_last_commit=diff_with_last_commit,
|
||||
filters=args.filters,
|
||||
json_output_file=args.json_output_file,
|
||||
filter_models=not commit_flags["no_filter"],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user