Document the test fetcher (#25521)
* Document the test fetcher * Address review comments
This commit is contained in:
@@ -18,20 +18,34 @@ Welcome to tests_fetcher V2.
|
|||||||
This util is designed to fetch tests to run on a PR so that only the tests impacted by the modifications are run, and
|
This util is designed to fetch tests to run on a PR so that only the tests impacted by the modifications are run, and
|
||||||
when too many models are being impacted, only run the tests of a subset of core models. It works like this.
|
when too many models are being impacted, only run the tests of a subset of core models. It works like this.
|
||||||
|
|
||||||
Stage 1: Identify the modified files. This takes all the files from the branching point to the current commit (so
|
Stage 1: Identify the modified files. For jobs that run on the main branch, it's just the diff with the last commit.
|
||||||
all modifications in a PR, not just the last commit) but excludes modifications that are on docstrings or comments
|
On a PR, this takes all the files from the branching point to the current commit (so all modifications in a PR, not
|
||||||
only.
|
just the last commit) but excludes modifications that are on docstrings or comments only.
|
||||||
|
|
||||||
Stage 2: Extract the tests to run. This is done by looking at the imports in each module and test file: if module A
|
Stage 2: Extract the tests to run. This is done by looking at the imports in each module and test file: if module A
|
||||||
imports module B, then changing module B impacts module A, so the tests using module A should be run. We thus get the
|
imports module B, then changing module B impacts module A, so the tests using module A should be run. We thus get the
|
||||||
dependencies of each model and then recursively builds the 'reverse' map of dependencies to get all modules and tests
|
dependencies of each model and then recursively builds the 'reverse' map of dependencies to get all modules and tests
|
||||||
impacted by a given file. We then only keep the tests (and only the code models tests if there are too many modules).
|
impacted by a given file. We then only keep the tests (and only the core models tests if there are too many modules).
|
||||||
|
|
||||||
Caveats:
|
Caveats:
|
||||||
- This module only filters tests by files (not individual tests) so it's better to have tests for different things
|
- This module only filters tests by files (not individual tests) so it's better to have tests for different things
|
||||||
in different files.
|
in different files.
|
||||||
- This module assumes inits are just importing things, not really building objects, so it's better to structure
|
- This module assumes inits are just importing things, not really building objects, so it's better to structure
|
||||||
them this way and move objects building in separate submodules.
|
them this way and move objects building in separate submodules.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
Base use to fetch the tests in a pull request
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python utils/tests_fetcher.py
|
||||||
|
```
|
||||||
|
|
||||||
|
Base use to fetch the tests on the main branch (with diff from the last commit):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python utils/tests_fetcher.py --diff_with_last_commit
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
@@ -41,6 +55,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from git import Repo
|
from git import Repo
|
||||||
|
|
||||||
@@ -80,9 +95,13 @@ IMPORTANT_MODELS = [
|
|||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def checkout_commit(repo, commit_id):
|
def checkout_commit(repo: Repo, commit_id: str):
|
||||||
"""
|
"""
|
||||||
Context manager that checks out a commit in the repo.
|
Context manager that checks out a given commit when entered, but gets back to the reference it was at on exit.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repo (`git.Repo`): A git repository (for instance the Transformers repo).
|
||||||
|
commit_id (`str`): The commit reference to checkout inside the context manager.
|
||||||
"""
|
"""
|
||||||
current_head = repo.head.commit if repo.head.is_detached else repo.head.ref
|
current_head = repo.head.commit if repo.head.is_detached else repo.head.ref
|
||||||
|
|
||||||
@@ -94,10 +113,19 @@ def checkout_commit(repo, commit_id):
|
|||||||
repo.git.checkout(current_head)
|
repo.git.checkout(current_head)
|
||||||
|
|
||||||
|
|
||||||
def clean_code(content):
|
def clean_code(content: str) -> str:
|
||||||
"""
|
"""
|
||||||
Remove docstrings, empty line or comments from `content`.
|
Remove docstrings, empty lines or comments from some code (used to detect if a diff is real or only concerns
|
||||||
|
comments or docstrings).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content (`str`): The code to clean
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`str`: The cleaned code.
|
||||||
"""
|
"""
|
||||||
|
# We need to deactivate autoformatting here to write escaped triple quotes (we cannot use real triple quotes or
|
||||||
|
# this would mess up the result if this function applied to this particular file).
|
||||||
# fmt: off
|
# fmt: off
|
||||||
# Remove docstrings by splitting on triple " then triple ':
|
# Remove docstrings by splitting on triple " then triple ':
|
||||||
splits = content.split('\"\"\"')
|
splits = content.split('\"\"\"')
|
||||||
@@ -111,15 +139,22 @@ def clean_code(content):
|
|||||||
for line in content.split("\n"):
|
for line in content.split("\n"):
|
||||||
# remove anything that is after a # sign.
|
# remove anything that is after a # sign.
|
||||||
line = re.sub("#.*$", "", line)
|
line = re.sub("#.*$", "", line)
|
||||||
if len(line) == 0 or line.isspace():
|
# remove white lines
|
||||||
continue
|
if len(line) != 0 and not line.isspace():
|
||||||
lines_to_keep.append(line)
|
lines_to_keep.append(line)
|
||||||
return "\n".join(lines_to_keep)
|
return "\n".join(lines_to_keep)
|
||||||
|
|
||||||
|
|
||||||
def keep_doc_examples_only(content):
|
def keep_doc_examples_only(content: str) -> str:
|
||||||
"""
|
"""
|
||||||
Remove code, docstring that is not code example, empty line or comments from `content`.
|
Remove everything from the code content except the doc examples (used to determine if a diff should trigger doc
|
||||||
|
tests or not).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content (`str`): The code to clean
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`str`: The cleaned code.
|
||||||
"""
|
"""
|
||||||
# Keep doc examples only by splitting on triple "`"
|
# Keep doc examples only by splitting on triple "`"
|
||||||
splits = content.split("```")
|
splits = content.split("```")
|
||||||
@@ -131,17 +166,18 @@ def keep_doc_examples_only(content):
|
|||||||
for line in content.split("\n"):
|
for line in content.split("\n"):
|
||||||
# remove anything that is after a # sign.
|
# remove anything that is after a # sign.
|
||||||
line = re.sub("#.*$", "", line)
|
line = re.sub("#.*$", "", line)
|
||||||
if len(line) == 0 or line.isspace():
|
# remove white lines
|
||||||
continue
|
if len(line) != 0 and not line.isspace():
|
||||||
lines_to_keep.append(line)
|
lines_to_keep.append(line)
|
||||||
return "\n".join(lines_to_keep)
|
return "\n".join(lines_to_keep)
|
||||||
|
|
||||||
|
|
||||||
def get_all_tests():
|
def get_all_tests() -> List[str]:
|
||||||
"""
|
"""
|
||||||
Return a list of paths to all test folders and files under `tests`. All paths are rooted at `tests`.
|
Walks the `tests` folder to return a list of files/subfolders. This is used to split the tests to run when using
|
||||||
|
parallelism. The split is:
|
||||||
|
|
||||||
- folders under `tests`: `tokenization`, `pipelines`, etc. The folder `models` is excluded.
|
- folders under `tests`: `tokenization`, `pipelines`, etc. (the subfolder `models` is excluded).
|
||||||
- folders under `tests/models`: `bert`, `gpt2`, etc.
|
- folders under `tests/models`: `bert`, `gpt2`, etc.
|
||||||
- test files under `tests`: `test_modeling_common.py`, `test_tokenization_common.py`, etc.
|
- test files under `tests`: `test_modeling_common.py`, `test_tokenization_common.py`, etc.
|
||||||
"""
|
"""
|
||||||
@@ -165,9 +201,17 @@ def get_all_tests():
|
|||||||
return tests
|
return tests
|
||||||
|
|
||||||
|
|
||||||
def diff_is_docstring_only(repo, branching_point, filename):
|
def diff_is_docstring_only(repo: Repo, branching_point: str, filename: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if the diff is only in docstrings in a filename.
|
Check if the diff is only in docstrings (or comments and whitespace) in a filename.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repo (`git.Repo`): A git repository (for instance the Transformers repo).
|
||||||
|
branching_point (`str`): The commit reference of where to compare for the diff.
|
||||||
|
filename (`str`): The filename where we want to know if the diff is only in docstrings/comments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`bool`: Whether the diff is docstring/comments only or not.
|
||||||
"""
|
"""
|
||||||
folder = Path(repo.working_dir)
|
folder = Path(repo.working_dir)
|
||||||
with checkout_commit(repo, branching_point):
|
with checkout_commit(repo, branching_point):
|
||||||
@@ -183,9 +227,17 @@ def diff_is_docstring_only(repo, branching_point, filename):
|
|||||||
return old_content_clean == new_content_clean
|
return old_content_clean == new_content_clean
|
||||||
|
|
||||||
|
|
||||||
def diff_contains_doc_examples(repo, branching_point, filename):
|
def diff_contains_doc_examples(repo: Repo, branching_point: str, filename: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if the diff is only in code in a filename.
|
Check if the diff is only in code examples of the doc in a filename.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repo (`git.Repo`): A git repository (for instance the Transformers repo).
|
||||||
|
branching_point (`str`): The commit reference of where to compare for the diff.
|
||||||
|
filename (`str`): The filename where we want to know if the diff is only in code examples.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`bool`: Whether the diff is only in code examples of the doc or not.
|
||||||
"""
|
"""
|
||||||
folder = Path(repo.working_dir)
|
folder = Path(repo.working_dir)
|
||||||
with checkout_commit(repo, branching_point):
|
with checkout_commit(repo, branching_point):
|
||||||
@@ -201,9 +253,22 @@ def diff_contains_doc_examples(repo, branching_point, filename):
|
|||||||
return old_content_clean != new_content_clean
|
return old_content_clean != new_content_clean
|
||||||
|
|
||||||
|
|
||||||
def get_diff(repo, base_commit, commits):
|
def get_diff(repo: Repo, base_commit: str, commits: List[str]) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Get's the diff between one or several commits and the head of the repository.
|
Get the diff between a base commit and one or several commits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repo (`git.Repo`):
|
||||||
|
A git repository (for instance the Transformers repo).
|
||||||
|
base_commit (`str`):
|
||||||
|
The commit reference of where to compare for the diff. This is the current commit, not the branching point!
|
||||||
|
commits (`List[str]`):
|
||||||
|
The list of commits with which to compare the repo at `base_commit` (so the branching point).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`List[str]`: The list of Python files with a diff (files added, renamed or deleted are always returned, files
|
||||||
|
modified are returned if the diff in the file is not only in docstrings or comments, see
|
||||||
|
`diff_is_docstring_only`).
|
||||||
"""
|
"""
|
||||||
print("\n### DIFF ###\n")
|
print("\n### DIFF ###\n")
|
||||||
code_diff = []
|
code_diff = []
|
||||||
@@ -230,12 +295,17 @@ def get_diff(repo, base_commit, commits):
|
|||||||
return code_diff
|
return code_diff
|
||||||
|
|
||||||
|
|
||||||
def get_modified_python_files(diff_with_last_commit=False):
|
def get_modified_python_files(diff_with_last_commit: bool = False) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Return a list of python files that have been modified between:
|
Return a list of python files that have been modified between:
|
||||||
|
|
||||||
- the current head and the main branch if `diff_with_last_commit=False` (default)
|
- the current head and the main branch if `diff_with_last_commit=False` (default)
|
||||||
- the current head and its parent commit otherwise.
|
- the current head and its parent commit otherwise.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`List[str]`: The list of Python files with a diff (files added, renamed or deleted are always returned, files
|
||||||
|
modified are returned if the diff in the file is not only in docstrings or comments, see
|
||||||
|
`diff_is_docstring_only`).
|
||||||
"""
|
"""
|
||||||
repo = Repo(PATH_TO_REPO)
|
repo = Repo(PATH_TO_REPO)
|
||||||
|
|
||||||
@@ -255,23 +325,34 @@ def get_modified_python_files(diff_with_last_commit=False):
|
|||||||
return get_diff(repo, repo.head.commit, parent_commits)
|
return get_diff(repo, repo.head.commit, parent_commits)
|
||||||
|
|
||||||
|
|
||||||
def get_diff_for_doctesting(repo, base_commit, commits):
|
def get_diff_for_doctesting(repo: Repo, base_commit: str, commits: List[str]) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Get's the diff between one or several commits and the head of the repository where some doc example(s) are changed.
|
Get the diff in doc examples between a base commit and one or several commits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repo (`git.Repo`):
|
||||||
|
A git repository (for instance the Transformers repo).
|
||||||
|
base_commit (`str`):
|
||||||
|
The commit reference of where to compare for the diff. This is the current commit, not the branching point!
|
||||||
|
commits (`List[str]`):
|
||||||
|
The list of commits with which to compare the repo at `base_commit` (so the branching point).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`List[str]`: The list of Python and Markdown files with a diff (files added or renamed are always returned, files
|
||||||
|
modified are returned if the diff in the file is only in doctest examples).
|
||||||
"""
|
"""
|
||||||
print("\n### DIFF ###\n")
|
print("\n### DIFF ###\n")
|
||||||
code_diff = []
|
code_diff = []
|
||||||
for commit in commits:
|
for commit in commits:
|
||||||
for diff_obj in commit.diff(base_commit):
|
for diff_obj in commit.diff(base_commit):
|
||||||
|
# We only consider Python files and doc files.
|
||||||
|
if not diff_obj.b_path.endswith(".py") and not diff_obj.b_path.endswith(".md"):
|
||||||
|
continue
|
||||||
# We always add new python/md files
|
# We always add new python/md files
|
||||||
if diff_obj.change_type in ["A"] and (diff_obj.b_path.endswith(".py") or diff_obj.b_path.endswith(".md")):
|
if diff_obj.change_type in ["A"]:
|
||||||
code_diff.append(diff_obj.b_path)
|
code_diff.append(diff_obj.b_path)
|
||||||
# Now for modified files
|
# Now for modified files
|
||||||
elif (
|
elif diff_obj.change_type in ["M", "R"]:
|
||||||
diff_obj.change_type in ["M", "R"]
|
|
||||||
and diff_obj.b_path.endswith(".py")
|
|
||||||
or diff_obj.b_path.endswith(".md")
|
|
||||||
):
|
|
||||||
# In case of renames, we'll look at the tests using both the old and new name.
|
# In case of renames, we'll look at the tests using both the old and new name.
|
||||||
if diff_obj.a_path != diff_obj.b_path:
|
if diff_obj.a_path != diff_obj.b_path:
|
||||||
code_diff.extend([diff_obj.a_path, diff_obj.b_path])
|
code_diff.extend([diff_obj.a_path, diff_obj.b_path])
|
||||||
@@ -285,12 +366,16 @@ def get_diff_for_doctesting(repo, base_commit, commits):
|
|||||||
return code_diff
|
return code_diff
|
||||||
|
|
||||||
|
|
||||||
def get_doctest_files(diff_with_last_commit=False):
|
def get_doctest_files(diff_with_last_commit: bool = False) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Return a list of python and mdx files where some doc example(s) in them have been modified between:
|
Return a list of python and Markdown files where doc examples have been modified between:
|
||||||
|
|
||||||
- the current head and the main branch if `diff_with_last_commit=False` (default)
|
- the current head and the main branch if `diff_with_last_commit=False` (default)
|
||||||
- the current head and its parent commit otherwise.
|
- the current head and its parent commit otherwise.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`List[str]`: The list of Python and Markdown files with a diff (files added or renamed are always returned, files
|
||||||
|
modified are returned if the diff in the file is only in doctest examples).
|
||||||
"""
|
"""
|
||||||
repo = Repo(PATH_TO_REPO)
|
repo = Repo(PATH_TO_REPO)
|
||||||
|
|
||||||
@@ -313,7 +398,7 @@ def get_doctest_files(diff_with_last_commit=False):
|
|||||||
# This is the full list of doctest tests
|
# This is the full list of doctest tests
|
||||||
with open("utils/documentation_tests.txt") as fp:
|
with open("utils/documentation_tests.txt") as fp:
|
||||||
documentation_tests = set(fp.read().strip().split("\n"))
|
documentation_tests = set(fp.read().strip().split("\n"))
|
||||||
# Not to run slow doctest tests
|
# Do not run slow doctest tests
|
||||||
with open("utils/slow_documentation_tests.txt") as fp:
|
with open("utils/slow_documentation_tests.txt") as fp:
|
||||||
slow_documentation_tests = set(fp.read().strip().split("\n"))
|
slow_documentation_tests = set(fp.read().strip().split("\n"))
|
||||||
|
|
||||||
@@ -348,10 +433,21 @@ _re_single_line_direct_imports = re.compile(r"(?:^|\n)\s*from\s+transformers(\S*
|
|||||||
_re_multi_line_direct_imports = re.compile(r"(?:^|\n)\s*from\s+transformers(\S*)\s+import\s+\(([^\)]+)\)")
|
_re_multi_line_direct_imports = re.compile(r"(?:^|\n)\s*from\s+transformers(\S*)\s+import\s+\(([^\)]+)\)")
|
||||||
|
|
||||||
|
|
||||||
def extract_imports(module_fname, cache=None):
|
def extract_imports(module_fname: str, cache: Dict[str, List[str]] = None) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Get the imports a given module makes. This takes a module filename and returns the list of module filenames
|
Get the imports a given module makes.
|
||||||
imported in the module with the objects imported in that module filename.
|
|
||||||
|
Args:
|
||||||
|
module_fname (`str`):
|
||||||
|
The name of the file of the module where we want to look at the imports (given relative to the root of
|
||||||
|
the repo).
|
||||||
|
cache (Dictionary `str` to `List[str]`, *optional*):
|
||||||
|
To speed up this function if it was previously called on `module_fname`, the cache of all previously
|
||||||
|
computed results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`List[str]`: The list of module filenames imported in the input `module_fname` (a submodule we import from that
|
||||||
|
is a subfolder will give its init file).
|
||||||
"""
|
"""
|
||||||
if cache is not None and module_fname in cache:
|
if cache is not None and module_fname in cache:
|
||||||
return cache[module_fname]
|
return cache[module_fname]
|
||||||
@@ -359,7 +455,8 @@ def extract_imports(module_fname, cache=None):
|
|||||||
with open(PATH_TO_REPO / module_fname, "r", encoding="utf-8") as f:
|
with open(PATH_TO_REPO / module_fname, "r", encoding="utf-8") as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
|
|
||||||
# Filter out all docstrings to not get imports in code examples.
|
# Filter out all docstrings to not get imports in code examples. As before we need to deactivate formatting to
|
||||||
|
# keep this as escaped quotes and avoid this function failing on this file.
|
||||||
# fmt: off
|
# fmt: off
|
||||||
splits = content.split('\"\"\"')
|
splits = content.split('\"\"\"')
|
||||||
# fmt: on
|
# fmt: on
|
||||||
@@ -376,6 +473,7 @@ def extract_imports(module_fname, cache=None):
|
|||||||
multiline_relative_imports = _re_multi_line_relative_imports.findall(content)
|
multiline_relative_imports = _re_multi_line_relative_imports.findall(content)
|
||||||
relative_imports += [(mod, imp) for mod, imp in multiline_relative_imports if "# tests_ignore" not in imp]
|
relative_imports += [(mod, imp) for mod, imp in multiline_relative_imports if "# tests_ignore" not in imp]
|
||||||
|
|
||||||
|
# We need to remove parts of the module name depending on the depth of the relative imports.
|
||||||
for module, imports in relative_imports:
|
for module, imports in relative_imports:
|
||||||
level = 0
|
level = 0
|
||||||
while module.startswith("."):
|
while module.startswith("."):
|
||||||
@@ -395,13 +493,15 @@ def extract_imports(module_fname, cache=None):
|
|||||||
multiline_direct_imports = _re_multi_line_direct_imports.findall(content)
|
multiline_direct_imports = _re_multi_line_direct_imports.findall(content)
|
||||||
direct_imports += [(mod, imp) for mod, imp in multiline_direct_imports if "# tests_ignore" not in imp]
|
direct_imports += [(mod, imp) for mod, imp in multiline_direct_imports if "# tests_ignore" not in imp]
|
||||||
|
|
||||||
|
# We need to find the relative path of those imports.
|
||||||
for module, imports in direct_imports:
|
for module, imports in direct_imports:
|
||||||
import_parts = module.split(".")[1:] # ignore the first .
|
import_parts = module.split(".")[1:] # ignore the name of the repo since we add it below.
|
||||||
dep_parts = ["src", "transformers"] + import_parts
|
dep_parts = ["src", "transformers"] + import_parts
|
||||||
imported_module = os.path.sep.join(dep_parts)
|
imported_module = os.path.sep.join(dep_parts)
|
||||||
imported_modules.append((imported_module, [imp.strip() for imp in imports.split(",")]))
|
imported_modules.append((imported_module, [imp.strip() for imp in imports.split(",")]))
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
|
# Double check we get proper modules (either a python file or a folder with an init).
|
||||||
for module_file, imports in imported_modules:
|
for module_file, imports in imported_modules:
|
||||||
if (PATH_TO_REPO / f"{module_file}.py").is_file():
|
if (PATH_TO_REPO / f"{module_file}.py").is_file():
|
||||||
module_file = f"{module_file}.py"
|
module_file = f"{module_file}.py"
|
||||||
@@ -417,15 +517,30 @@ def extract_imports(module_fname, cache=None):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def get_module_dependencies(module_fname, cache=None):
|
def get_module_dependencies(module_fname: str, cache: Dict[str, List[str]] = None) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Get the dependencies of a module from the module filename as a list of module filenames. This will resolve any
|
Refines the result of `extract_imports` to remove subfolders and get a proper list of module filenames: if a file
|
||||||
__init__ we pass: if we import from a submodule utils, the dependencies will be utils/foo.py and utils/bar.py (if
|
has an import `from utils import Foo, Bar`, with `utils` being a subfolder containing many files, this will traverse
|
||||||
the objects imported actually come from utils.foo and utils.bar) not utils/__init__.py.
|
the `utils` init file to check where those dependencies come from: for instance the files utils/foo.py and utils/bar.py.
|
||||||
|
|
||||||
|
Warning: This presupposes that all intermediate inits are properly built (with imports from the respective
|
||||||
|
submodules) and work better if objects are defined in submodules and not the intermediate init (otherwise the
|
||||||
|
intermediate init is added, and inits usually have a lot of dependencies).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module_fname (`str`):
|
||||||
|
The name of the file of the module where we want to look at the imports (given relative to the root of
|
||||||
|
the repo).
|
||||||
|
cache (Dictionary `str` to `List[str]`, *optional*):
|
||||||
|
To speed up this function if it was previously called on `module_fname`, the cache of all previously
|
||||||
|
computed results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`List[str]`: The list of module filenames imported in the input `module_fname` (with submodule imports refined).
|
||||||
"""
|
"""
|
||||||
dependencies = []
|
dependencies = []
|
||||||
imported_modules = extract_imports(module_fname, cache=cache)
|
imported_modules = extract_imports(module_fname, cache=cache)
|
||||||
# The while loop is to recursively traverse all inits we may encounter.
|
# The while loop is to recursively traverse all inits we may encounter: we will add things as we go.
|
||||||
while len(imported_modules) > 0:
|
while len(imported_modules) > 0:
|
||||||
new_modules = []
|
new_modules = []
|
||||||
for module, imports in imported_modules:
|
for module, imports in imported_modules:
|
||||||
@@ -461,7 +576,7 @@ def get_module_dependencies(module_fname, cache=None):
|
|||||||
return dependencies
|
return dependencies
|
||||||
|
|
||||||
|
|
||||||
def create_reverse_dependency_tree():
|
def create_reverse_dependency_tree() -> List[Tuple[str, str]]:
|
||||||
"""
|
"""
|
||||||
Create a list of all edges (a, b) which mean that modifying a impacts b with a going over all module and test files.
|
Create a list of all edges (a, b) which mean that modifying a impacts b with a going over all module and test files.
|
||||||
"""
|
"""
|
||||||
@@ -473,10 +588,17 @@ def create_reverse_dependency_tree():
|
|||||||
return list(set(edges))
|
return list(set(edges))
|
||||||
|
|
||||||
|
|
||||||
def get_tree_starting_at(module, edges):
|
def get_tree_starting_at(module: str, edges: List[Tuple[str, str]]) -> List[Union[str, List[str]]]:
|
||||||
"""
|
"""
|
||||||
Returns the tree starting at a given module following all edges in the following format: [module, [list of edges
|
Returns the tree starting at a given module following all edges.
|
||||||
starting at module], [list of edges starting at the preceding level], ...]
|
|
||||||
|
Args:
|
||||||
|
module (`str`): The module that will be the root of the subtree we want.
|
||||||
|
edges (`List[Tuple[str, str]]`): The list of all edges of the tree.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`List[Union[str, List[str]]]`: The tree to print in the following format: [module, [list of edges
|
||||||
|
starting at module], [list of edges starting at the preceding level], ...]
|
||||||
"""
|
"""
|
||||||
vertices_seen = [module]
|
vertices_seen = [module]
|
||||||
new_edges = [edge for edge in edges if edge[0] == module and edge[1] != module and "__init__.py" not in edge[1]]
|
new_edges = [edge for edge in edges if edge[0] == module and edge[1] != module and "__init__.py" not in edge[1]]
|
||||||
@@ -497,6 +619,11 @@ def get_tree_starting_at(module, edges):
|
|||||||
def print_tree_deps_of(module, all_edges=None):
|
def print_tree_deps_of(module, all_edges=None):
|
||||||
"""
|
"""
|
||||||
Prints the tree of modules depending on a given module.
|
Prints the tree of modules depending on a given module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module (`str`): The module that will be the root of the subtree we want.
|
||||||
|
all_edges (`List[Tuple[str, str]]`, *optional*):
|
||||||
|
The list of all edges of the tree. Will be set to `create_reverse_dependency_tree()` if not passed.
|
||||||
"""
|
"""
|
||||||
if all_edges is None:
|
if all_edges is None:
|
||||||
all_edges = create_reverse_dependency_tree()
|
all_edges = create_reverse_dependency_tree()
|
||||||
@@ -522,16 +649,24 @@ def print_tree_deps_of(module, all_edges=None):
|
|||||||
print(line[0])
|
print(line[0])
|
||||||
|
|
||||||
|
|
||||||
def init_test_examples_dependencies():
|
def init_test_examples_dependencies() -> Tuple[Dict[str, List[str]], List[str]]:
|
||||||
"""
|
"""
|
||||||
The test examples do not import from the examples (which are just scripts, not modules) so we need some extra
|
The test examples do not import from the examples (which are just scripts, not modules) so we need some extra
|
||||||
care initializing the dependency map there.
|
care initializing the dependency map, which is the goal of this function. It initializes the dependency map for
|
||||||
|
example files by linking each example to the example test file for the example framework.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Tuple[Dict[str, List[str]], List[str]]`: A tuple with two elements: the initialized dependency map which is a
|
||||||
|
dict test example file to list of example files potentially tested by that test file, and the list of all
|
||||||
|
example files (to avoid recomputing it later).
|
||||||
"""
|
"""
|
||||||
test_example_deps = {}
|
test_example_deps = {}
|
||||||
all_examples = []
|
all_examples = []
|
||||||
for framework in ["flax", "pytorch", "tensorflow"]:
|
for framework in ["flax", "pytorch", "tensorflow"]:
|
||||||
test_files = list((PATH_TO_EXAMPLES / framework).glob("test_*.py"))
|
test_files = list((PATH_TO_EXAMPLES / framework).glob("test_*.py"))
|
||||||
all_examples.extend(test_files)
|
all_examples.extend(test_files)
|
||||||
|
# Remove the files at the root of examples/framework since they are not proper examples (they are either utils
|
||||||
|
# or example test files).
|
||||||
examples = [
|
examples = [
|
||||||
f for f in (PATH_TO_EXAMPLES / framework).glob("**/*.py") if f.parent != PATH_TO_EXAMPLES / framework
|
f for f in (PATH_TO_EXAMPLES / framework).glob("**/*.py") if f.parent != PATH_TO_EXAMPLES / framework
|
||||||
]
|
]
|
||||||
@@ -539,24 +674,33 @@ def init_test_examples_dependencies():
|
|||||||
for test_file in test_files:
|
for test_file in test_files:
|
||||||
with open(test_file, "r", encoding="utf-8") as f:
|
with open(test_file, "r", encoding="utf-8") as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
|
# Map all examples to the test files found in examples/framework.
|
||||||
test_example_deps[str(test_file.relative_to(PATH_TO_REPO))] = [
|
test_example_deps[str(test_file.relative_to(PATH_TO_REPO))] = [
|
||||||
str(e.relative_to(PATH_TO_REPO)) for e in examples if e.name in content
|
str(e.relative_to(PATH_TO_REPO)) for e in examples if e.name in content
|
||||||
]
|
]
|
||||||
|
# Also map the test files to themselves.
|
||||||
test_example_deps[str(test_file.relative_to(PATH_TO_REPO))].append(
|
test_example_deps[str(test_file.relative_to(PATH_TO_REPO))].append(
|
||||||
str(test_file.relative_to(PATH_TO_REPO))
|
str(test_file.relative_to(PATH_TO_REPO))
|
||||||
)
|
)
|
||||||
return test_example_deps, all_examples
|
return test_example_deps, all_examples
|
||||||
|
|
||||||
|
|
||||||
def create_reverse_dependency_map():
|
def create_reverse_dependency_map() -> Dict[str, List[str]]:
|
||||||
"""
|
"""
|
||||||
Create the dependency map from module/test filename to the list of modules/tests that depend on it (even
|
Create the dependency map from module/test filename to the list of modules/tests that depend on it recursively.
|
||||||
recursively).
|
|
||||||
|
Returns:
|
||||||
|
`Dict[str, List[str]]`: The reverse dependency map as a dictionary mapping filenames to all the filenames
|
||||||
|
depending on it recursively. This way the tests impacted by a change in file A are the test files in the list
|
||||||
|
corresponding to key A in this result.
|
||||||
"""
|
"""
|
||||||
cache = {}
|
cache = {}
|
||||||
|
# Start from the example deps init.
|
||||||
example_deps, examples = init_test_examples_dependencies()
|
example_deps, examples = init_test_examples_dependencies()
|
||||||
|
# Add all modules and all tests to all examples
|
||||||
all_modules = list(PATH_TO_TRANFORMERS.glob("**/*.py")) + list(PATH_TO_TESTS.glob("**/*.py")) + examples
|
all_modules = list(PATH_TO_TRANFORMERS.glob("**/*.py")) + list(PATH_TO_TESTS.glob("**/*.py")) + examples
|
||||||
all_modules = [str(mod.relative_to(PATH_TO_REPO)) for mod in all_modules]
|
all_modules = [str(mod.relative_to(PATH_TO_REPO)) for mod in all_modules]
|
||||||
|
# Compute the direct dependencies of all modules.
|
||||||
direct_deps = {m: get_module_dependencies(m, cache=cache) for m in all_modules}
|
direct_deps = {m: get_module_dependencies(m, cache=cache) for m in all_modules}
|
||||||
direct_deps.update(example_deps)
|
direct_deps.update(example_deps)
|
||||||
|
|
||||||
@@ -566,6 +710,8 @@ def create_reverse_dependency_map():
|
|||||||
something_changed = False
|
something_changed = False
|
||||||
for m in all_modules:
|
for m in all_modules:
|
||||||
for d in direct_deps[m]:
|
for d in direct_deps[m]:
|
||||||
|
# We stop recursing at an init (because we always end up in the main init and we don't want to add all
|
||||||
|
# files which the main init imports)
|
||||||
if d.endswith("__init__.py"):
|
if d.endswith("__init__.py"):
|
||||||
continue
|
continue
|
||||||
if d not in direct_deps:
|
if d not in direct_deps:
|
||||||
@@ -581,6 +727,8 @@ def create_reverse_dependency_map():
|
|||||||
for d in direct_deps[m]:
|
for d in direct_deps[m]:
|
||||||
reverse_map[d].append(m)
|
reverse_map[d].append(m)
|
||||||
|
|
||||||
|
# For inits, we don't do the reverse deps but the direct deps: if modifying an init, we want to make sure we test
|
||||||
|
# all the modules impacted by that init.
|
||||||
for m in [f for f in all_modules if f.endswith("__init__.py")]:
|
for m in [f for f in all_modules if f.endswith("__init__.py")]:
|
||||||
direct_deps = get_module_dependencies(m, cache=cache)
|
direct_deps = get_module_dependencies(m, cache=cache)
|
||||||
deps = sum([reverse_map[d] for d in direct_deps if not d.endswith("__init__.py")], direct_deps)
|
deps = sum([reverse_map[d] for d in direct_deps if not d.endswith("__init__.py")], direct_deps)
|
||||||
@@ -589,13 +737,26 @@ def create_reverse_dependency_map():
|
|||||||
return reverse_map
|
return reverse_map
|
||||||
|
|
||||||
|
|
||||||
def create_module_to_test_map(reverse_map=None, filter_models=False):
|
def create_module_to_test_map(
|
||||||
|
reverse_map: Dict[str, List[str]] = None, filter_models: bool = False
|
||||||
|
) -> Dict[str, List[str]]:
|
||||||
"""
|
"""
|
||||||
Extract the tests from the reverse_dependency_map and potentially filter the model tests.
|
Extract the tests from the reverse_dependency_map and potentially filter the model tests.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reverse_map (`Dict[str, List[str]]`, *optional*):
|
||||||
|
The reverse dependency map as created by `create_reverse_dependency_map`. Will default to the result of
|
||||||
|
that function if not provided.
|
||||||
|
filter_models (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to filter model tests to only include core models if a file impacts a lot of models.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Dict[str, List[str]]`: A dictionary that maps each file to the tests to execute if that file was modified.
|
||||||
"""
|
"""
|
||||||
if reverse_map is None:
|
if reverse_map is None:
|
||||||
reverse_map = create_reverse_dependency_map()
|
reverse_map = create_reverse_dependency_map()
|
||||||
|
|
||||||
|
# Utility that tells us if a given file is a test (taking test examples into account)
|
||||||
def is_test(fname):
|
def is_test(fname):
|
||||||
if fname.startswith("tests"):
|
if fname.startswith("tests"):
|
||||||
return True
|
return True
|
||||||
@@ -603,14 +764,17 @@ def create_module_to_test_map(reverse_map=None, filter_models=False):
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# Build the test map
|
||||||
test_map = {module: [f for f in deps if is_test(f)] for module, deps in reverse_map.items()}
|
test_map = {module: [f for f in deps if is_test(f)] for module, deps in reverse_map.items()}
|
||||||
|
|
||||||
if not filter_models:
|
if not filter_models:
|
||||||
return test_map
|
return test_map
|
||||||
|
|
||||||
|
# Now we deal with the filtering if `filter_models` is True.
|
||||||
num_model_tests = len(list(PATH_TO_TESTS.glob("models/*")))
|
num_model_tests = len(list(PATH_TO_TESTS.glob("models/*")))
|
||||||
|
|
||||||
def has_many_models(tests):
|
def has_many_models(tests):
|
||||||
|
# We filter to core models when a given file impacts more than half the model tests.
|
||||||
model_tests = {Path(t).parts[2] for t in tests if t.startswith("tests/models/")}
|
model_tests = {Path(t).parts[2] for t in tests if t.startswith("tests/models/")}
|
||||||
return len(model_tests) > num_model_tests // 2
|
return len(model_tests) > num_model_tests // 2
|
||||||
|
|
||||||
@@ -623,7 +787,7 @@ def create_module_to_test_map(reverse_map=None, filter_models=False):
|
|||||||
def check_imports_all_exist():
|
def check_imports_all_exist():
|
||||||
"""
|
"""
|
||||||
Isn't used per se by the test fetcher but might be used later as a quality check. Putting this here for now so the
|
Isn't used per se by the test fetcher but might be used later as a quality check. Putting this here for now so the
|
||||||
code is not lost.
|
code is not lost. This checks all imports in a given file do exist.
|
||||||
"""
|
"""
|
||||||
cache = {}
|
cache = {}
|
||||||
all_modules = list(PATH_TO_TRANFORMERS.glob("**/*.py")) + list(PATH_TO_TESTS.glob("**/*.py"))
|
all_modules = list(PATH_TO_TRANFORMERS.glob("**/*.py")) + list(PATH_TO_TESTS.glob("**/*.py"))
|
||||||
@@ -636,11 +800,21 @@ def check_imports_all_exist():
|
|||||||
print(f"{module} has dependency on {dep} which does not exist.")
|
print(f"{module} has dependency on {dep} which does not exist.")
|
||||||
|
|
||||||
|
|
||||||
def _print_list(l):
|
def _print_list(l) -> str:
|
||||||
|
"""
|
||||||
|
Format a list of elements as a string with one line per element and a - starting each line.
|
||||||
|
"""
|
||||||
return "\n".join([f"- {f}" for f in l])
|
return "\n".join([f"- {f}" for f in l])
|
||||||
|
|
||||||
|
|
||||||
def create_json_map(test_files_to_run, json_output_file):
|
def create_json_map(test_files_to_run: List[str], json_output_file: str):
|
||||||
|
"""
|
||||||
|
Creates a map from a list of tests to run, so they can easily be split by category when running the slow tests in parallel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
test_files_to_run (`List[str]`): The list of tests to run.
|
||||||
|
json_output_file (`str`): The path where to store the built json map.
|
||||||
|
"""
|
||||||
if json_output_file is None:
|
if json_output_file is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -673,7 +847,34 @@ def create_json_map(test_files_to_run, json_output_file):
|
|||||||
json.dump(test_map, fp, ensure_ascii=False)
|
json.dump(test_map, fp, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
def infer_tests_to_run(output_file, diff_with_last_commit=False, filter_models=True, json_output_file=None):
|
def infer_tests_to_run(
|
||||||
|
output_file: str,
|
||||||
|
diff_with_last_commit: bool = False,
|
||||||
|
filter_models: bool = True,
|
||||||
|
json_output_file: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
The main function called by the test fetcher. Determines the tests to run from the diff.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_file (`str`):
|
||||||
|
The path where to store the summary of the test fetcher analysis. Other files will be stored in the same
|
||||||
|
folder:
|
||||||
|
|
||||||
|
- examples_test_list.txt: The list of examples tests to run.
|
||||||
|
- test_repo_utils.txt: Will indicate if the repo utils tests should be run or not.
|
||||||
|
- doctest_list.txt: The list of doctests to run.
|
||||||
|
|
||||||
|
diff_with_last_commit (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to analyze the diff with the last commit (for use on the main branch after a PR is merged) or with
|
||||||
|
the branching point from main (for use on each PR).
|
||||||
|
filter_models (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not to filter the tests to core models only, when a modified file results in a lot of model
|
||||||
|
tests.
|
||||||
|
json_output_file (`str`, *optional*):
|
||||||
|
The path where to store the json file mapping categories of tests to tests to run (used for parallelism of
|
||||||
|
the slow tests).
|
||||||
|
"""
|
||||||
modified_files = get_modified_python_files(diff_with_last_commit=diff_with_last_commit)
|
modified_files = get_modified_python_files(diff_with_last_commit=diff_with_last_commit)
|
||||||
print(f"\n### MODIFIED FILES ###\n{_print_list(modified_files)}")
|
print(f"\n### MODIFIED FILES ###\n{_print_list(modified_files)}")
|
||||||
|
|
||||||
@@ -751,7 +952,7 @@ def infer_tests_to_run(output_file, diff_with_last_commit=False, filter_models=T
|
|||||||
f.write(" ".join(doctest_list))
|
f.write(" ".join(doctest_list))
|
||||||
|
|
||||||
|
|
||||||
def filter_tests(output_file, filters):
|
def filter_tests(output_file: str, filters: List[str]):
|
||||||
"""
|
"""
|
||||||
Reads the content of the output file and filters out all the tests in a list of given folders.
|
Reads the content of the output file and filters out all the tests in a list of given folders.
|
||||||
|
|
||||||
@@ -778,11 +979,16 @@ def filter_tests(output_file, filters):
|
|||||||
f.write(" ".join(test_files))
|
f.write(" ".join(test_files))
|
||||||
|
|
||||||
|
|
||||||
def parse_commit_message(commit_message):
|
def parse_commit_message(commit_message: str) -> Dict[str, bool]:
|
||||||
"""
|
"""
|
||||||
Parses the commit message to detect if a command is there to skip, force all or part of the CI.
|
Parses the commit message to detect if a command is there to skip, force all or part of the CI.
|
||||||
|
|
||||||
Returns a dictionary of strings to bools with keys skip, test_all_models and test_all.
|
Args:
|
||||||
|
commit_message (`str`): The commit message of the current commit.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Dict[str, bool]`: A dictionary of strings to bools with the following keys: `"skip"`,
|
||||||
|
`"no_filter"` and `"test_all"`.
|
||||||
"""
|
"""
|
||||||
if commit_message is None:
|
if commit_message is None:
|
||||||
return {"skip": False, "no_filter": False, "test_all": False}
|
return {"skip": False, "no_filter": False, "test_all": False}
|
||||||
|
|||||||
Reference in New Issue
Block a user