Test fetch v2 (#22367)

* Test fetcher v2

* Fix regexes

* Remove sanity check

* Fake modification to OPT

* Fixes some .sep issues

* Remove fake OPT change

* Fake modif for BERT

* Fake modif for init

* Exclude SageMaker tests

* Fix test and remove fake modif

* Fake setup modif

* Fake pipeline modif

* Remove all fake modifs

* Adds options to skip/force tests

* [test-all-models] Fake modif for BERT

* Try this way

* Does the command actually work?

* [test-all-models] Try again!

* [skip circleci] Remove fake modif

* Remove debug statements

* Add the list of important models

* Quality

* Update utils/tests_fetcher.py

Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>

* Address review comments

* Address review comments

* Fix and add test

* Apply suggestions from code review

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>

* Address review comments

---------

Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
Sylvain Gugger
2023-03-31 16:18:43 -04:00
committed by GitHub
parent 3a9464bd30
commit c612628045
4 changed files with 1005 additions and 412 deletions

View File

@@ -13,6 +13,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Welcome to tests_fetcher V2.
This util is designed to fetch tests to run on a PR so that only the tests impacted by the modifications are run, and
when too many models are being impacted, only run the tests of a subset of core models. It works like this.
Stage 1: Identify the modified files. This takes all the files from the branching point to the current commit (so
all modifications in a PR, not just the last commit) but excludes modifications that are on docstrings or comments
only.
Stage 2: Extract the tests to run. This is done by looking at the imports in each module and test file: if module A
imports module B, then changing module B impacts module A, so the tests using module A should be run. We thus get the
dependencies of each module and then recursively build the 'reverse' map of dependencies to get all modules and tests
impacted by a given file. We then only keep the tests (and only the core models tests if there are too many models).
Caveats:
- This module only filters tests by files (not individual tests) so it's better to have tests for different things
in different files.
- This module assumes inits are just importing things, not really building objects, so it's better to structure
them this way and move object building into separate submodules.
"""
import argparse
import collections
import json
@@ -24,13 +45,36 @@ from pathlib import Path
from git import Repo
# This script is intended to be run from the root of the repo but you can adapt this constant if you need to.
PATH_TO_TRANFORMERS = "."
PATH_TO_REPO = Path(__file__).parent.parent.resolve()
PATH_TO_TRANFORMERS = PATH_TO_REPO / "src/transformers"
PATH_TO_TESTS = PATH_TO_REPO / "tests"
# A temporary way to trigger all pipeline tests contained in model test files after PR #21516
all_model_test_files = [str(x) for x in Path("tests/models/").glob("**/**/test_modeling_*.py")]
all_pipeline_test_files = [str(x) for x in Path("tests/pipelines/").glob("**/test_pipelines_*.py")]
# List here the models to always test.
IMPORTANT_MODELS = [
# Most downloaded models
"bert",
"clip",
"t5",
"xlm-roberta",
"gpt2",
"bart",
"mpnet",
"gpt-j",
"wav2vec2",
"deberta-v2",
"layoutlm",
"opt",
"longformer",
"vit",
# Pipeline-specific models (to be sure each pipeline has one model in this list)
"tapas",
"vilt",
"clap",
"detr",
"owlvit",
"dpt",
"videomae",
]
@contextmanager
@@ -79,17 +123,21 @@ def get_all_tests():
- folders under `tests/models`: `bert`, `gpt2`, etc.
- test files under `tests`: `test_modeling_common.py`, `test_tokenization_common.py`, etc.
"""
test_root_dir = os.path.join(PATH_TO_TRANFORMERS, "tests")
# test folders/files directly under `tests` folder
tests = os.listdir(test_root_dir)
tests = sorted(filter(lambda x: os.path.isdir(x) or x.startswith("tests/test_"), [f"tests/{x}" for x in tests]))
tests = os.listdir(PATH_TO_TESTS)
tests = [f"tests/{f}" for f in tests if "__pycache__" not in f]
tests = sorted([f for f in tests if (PATH_TO_REPO / f).is_dir() or f.startswith("tests/test_")])
# model specific test folders
model_tests_folders = os.listdir(os.path.join(test_root_dir, "models"))
model_test_folders = sorted(filter(os.path.isdir, [f"tests/models/{x}" for x in model_tests_folders]))
model_test_folders = os.listdir(PATH_TO_TESTS / "models")
model_test_folders = [f"tests/models/{f}" for f in model_test_folders if "__pycache__" not in f]
model_test_folders = sorted([f for f in model_test_folders if (PATH_TO_REPO / f).is_dir()])
tests.remove("tests/models")
# Sagemaker tests are not meant to be run on the CI.
if "tests/sagemaker" in tests:
tests.remove("tests/sagemaker")
tests = model_test_folders + tests
return tests
@@ -99,11 +147,12 @@ def diff_is_docstring_only(repo, branching_point, filename):
"""
Check if the diff is only in docstrings in a filename.
"""
folder = Path(repo.working_dir)
with checkout_commit(repo, branching_point):
with open(filename, "r", encoding="utf-8") as f:
with open(folder / filename, "r", encoding="utf-8") as f:
old_content = f.read()
with open(filename, "r", encoding="utf-8") as f:
with open(folder / filename, "r", encoding="utf-8") as f:
new_content = f.read()
old_content_clean = clean_code(old_content)
@@ -112,31 +161,6 @@ def diff_is_docstring_only(repo, branching_point, filename):
return old_content_clean == new_content_clean
def get_modified_python_files(diff_with_last_commit=False):
    """
    Return the list of python files that have been modified, as computed by `get_diff`, between:

    - the current head and its merge base with the main branch if `diff_with_last_commit=False` (default)
    - the current head and its parent commit(s) otherwise.
    """
    repo = Repo(PATH_TO_TRANFORMERS)

    if diff_with_last_commit:
        # In this mode the current head is the commit reported as `main`.
        print(f"main is at {repo.head.commit}")
        reference_commits = repo.head.commit.parents
        for parent in reference_commits:
            print(f"Parent commit: {parent}")
        return get_diff(repo, repo.head.commit, reference_commits)

    print(f"main is at {repo.refs.main.commit}")
    print(f"Current head is at {repo.head.commit}")
    reference_commits = repo.merge_base(repo.refs.main, repo.head)
    for merge_base in reference_commits:
        print(f"Branching commit: {merge_base}")
    return get_diff(repo, repo.head.commit, reference_commits)
def get_diff(repo, base_commit, commits):
"""
Gets the diff between one or several commits and the head of the repository.
@@ -166,96 +190,173 @@ def get_diff(repo, base_commit, commits):
return code_diff
def get_module_dependencies(module_fname):
def get_modified_python_files(diff_with_last_commit=False):
"""
Get the dependencies of a module.
Return a list of python files that have been modified between:
- the current head and the main branch if `diff_with_last_commit=False` (default)
- the current head and its parent commit otherwise.
"""
with open(os.path.join(PATH_TO_TRANFORMERS, module_fname), "r", encoding="utf-8") as f:
repo = Repo(PATH_TO_REPO)
if not diff_with_last_commit:
print(f"main is at {repo.refs.main.commit}")
print(f"Current head is at {repo.head.commit}")
branching_commits = repo.merge_base(repo.refs.main, repo.head)
for commit in branching_commits:
print(f"Branching commit: {commit}")
return get_diff(repo, repo.head.commit, branching_commits)
else:
print(f"main is at {repo.head.commit}")
parent_commits = repo.head.commit.parents
for commit in parent_commits:
print(f"Parent commit: {commit}")
return get_diff(repo, repo.head.commit, parent_commits)
# (?:^|\n) -> Non-capturing group for the beginning of the doc or a new line.
# \s*from\s+(\.+\S+)\s+import\s+([^\n]+) -> Line only contains from .xxx import yyy and we catch .xxx and yyy
# (?=\n) -> Look-ahead to a new line. We can't just put \n here or using findall on this re will only catch every
# other import.
_re_single_line_relative_imports = re.compile(r"(?:^|\n)\s*from\s+(\.+\S+)\s+import\s+([^\n]+)(?=\n)")
# (?:^|\n) -> Non-capturing group for the beginning of the doc or a new line.
# \s*from\s+(\.+\S+)\s+import\s+\(([^\)]+)\) -> Line continues with from .xxx import (yyy) and we catch .xxx and yyy.
# yyy will take multiple lines otherwise there wouldn't be parentheses.
_re_multi_line_relative_imports = re.compile(r"(?:^|\n)\s*from\s+(\.+\S+)\s+import\s+\(([^\)]+)\)")
# (?:^|\n) -> Non-capturing group for the beginning of the doc or a new line.
# \s*from\s+transformers(\S*)\s+import\s+([^\n]+) -> Line only contains from transformers.xxx import yyy and we catch
# .xxx and yyy
# (?=\n) -> Look-ahead to a new line. We can't just put \n here or using findall on this re will only catch every
# other import.
_re_single_line_direct_imports = re.compile(r"(?:^|\n)\s*from\s+transformers(\S*)\s+import\s+([^\n]+)(?=\n)")
# (?:^|\n) -> Non-capturing group for the beginning of the doc or a new line.
# \s*from\s+transformers(\S*)\s+import\s+\(([^\)]+)\) -> Line continues with from transformers.xxx import (yyy) and we
# catch .xxx and yyy. yyy will take multiple lines otherwise there wouldn't be parentheses.
_re_multi_line_direct_imports = re.compile(r"(?:^|\n)\s*from\s+transformers(\S*)\s+import\s+\(([^\)]+)\)")
def extract_imports(module_fname, cache=None):
"""
Get the imports a given module makes. This takes a module filename and returns the list of module filenames
imported in the module with the objects imported in that module filename.
"""
if cache is not None and module_fname in cache:
return cache[module_fname]
with open(PATH_TO_REPO / module_fname, "r", encoding="utf-8") as f:
content = f.read()
module_parts = module_fname.split(os.path.sep)
# Filter out all docstrings to not get imports in code examples.
splits = content.split('"""')
content = "".join(splits[::2])
module_parts = str(module_fname).split(os.path.sep)
imported_modules = []
# Let's start with relative imports
relative_imports = re.findall(r"from\s+(\.+\S+)\s+import\s+([^\n]+)\n", content)
relative_imports = [mod for mod, imp in relative_imports if "# tests_ignore" not in imp]
for imp in relative_imports:
relative_imports = _re_single_line_relative_imports.findall(content)
relative_imports = [
(mod, imp) for mod, imp in relative_imports if "# tests_ignore" not in imp and imp.strip() != "("
]
multiline_relative_imports = _re_multi_line_relative_imports.findall(content)
relative_imports += [(mod, imp) for mod, imp in multiline_relative_imports if "# tests_ignore" not in imp]
for module, imports in relative_imports:
level = 0
while imp.startswith("."):
imp = imp[1:]
while module.startswith("."):
module = module[1:]
level += 1
if len(imp) > 0:
dep_parts = module_parts[: len(module_parts) - level] + imp.split(".")
if len(module) > 0:
dep_parts = module_parts[: len(module_parts) - level] + module.split(".")
else:
dep_parts = module_parts[: len(module_parts) - level] + ["__init__.py"]
dep_parts = module_parts[: len(module_parts) - level]
imported_module = os.path.sep.join(dep_parts)
# We ignore the main init import as it's only for the __version__ that it's done
# and it would add everything as a dependency.
if not imported_module.endswith("transformers/__init__.py"):
imported_modules.append(imported_module)
imported_modules.append((imported_module, [imp.strip() for imp in imports.split(",")]))
# Let's continue with direct imports
# The import from the transformers module are ignored for the same reason we ignored the
# main init before.
direct_imports = re.findall(r"from\s+transformers\.(\S+)\s+import\s+([^\n]+)\n", content)
direct_imports = [mod for mod, imp in direct_imports if "# tests_ignore" not in imp]
for imp in direct_imports:
import_parts = imp.split(".")
direct_imports = _re_single_line_direct_imports.findall(content)
direct_imports = [(mod, imp) for mod, imp in direct_imports if "# tests_ignore" not in imp and imp.strip() != "("]
multiline_direct_imports = _re_multi_line_direct_imports.findall(content)
direct_imports += [(mod, imp) for mod, imp in multiline_direct_imports if "# tests_ignore" not in imp]
for module, imports in direct_imports:
import_parts = module.split(".")[1:] # ignore the first .
dep_parts = ["src", "transformers"] + import_parts
imported_modules.append(os.path.sep.join(dep_parts))
imported_module = os.path.sep.join(dep_parts)
imported_modules.append((imported_module, [imp.strip() for imp in imports.split(",")]))
# Now let's just check that we have proper module files, or append an init for submodules
result = []
for module_file, imports in imported_modules:
if (PATH_TO_REPO / f"{module_file}.py").is_file():
module_file = f"{module_file}.py"
elif (PATH_TO_REPO / module_file).is_dir() and (PATH_TO_REPO / module_file / "__init__.py").is_file():
module_file = os.path.sep.join([module_file, "__init__.py"])
imports = [imp for imp in imports if len(imp) > 0 and re.match("^[A-Za-z0-9_]*$", imp)]
if len(imports) > 0:
result.append((module_file, imports))
if cache is not None:
cache[module_fname] = result
return result
def get_module_dependencies(module_fname, cache=None):
"""
Get the dependencies of a module from the module filename as a list of module filenames. This will resolve any
__init__ we pass: if we import from a submodule utils, the dependencies will be utils/foo.py and utils/bar.py (if
the objects imported actually come from utils.foo and utils.bar) not utils/__init__.py.
"""
dependencies = []
for imported_module in imported_modules:
if os.path.isfile(os.path.join(PATH_TO_TRANFORMERS, f"{imported_module}.py")):
dependencies.append(f"{imported_module}.py")
elif os.path.isdir(os.path.join(PATH_TO_TRANFORMERS, imported_module)) and os.path.isfile(
os.path.sep.join([PATH_TO_TRANFORMERS, imported_module, "__init__.py"])
):
dependencies.append(os.path.sep.join([imported_module, "__init__.py"]))
imported_modules = extract_imports(module_fname, cache=cache)
# The while loop is to recursively traverse all inits we may encounter.
while len(imported_modules) > 0:
new_modules = []
for module, imports in imported_modules:
# If we end up in an __init__ we are often not actually importing from this init (except in the case where
# the object is fully defined in the __init__)
if module.endswith("__init__.py"):
# So we get the imports from that init then try to find where our objects come from.
new_imported_modules = extract_imports(module, cache=cache)
for new_module, new_imports in new_imported_modules:
if any([i in new_imports for i in imports]):
if new_module not in dependencies:
new_modules.append((new_module, [i for i in new_imports if i in imports]))
imports = [i for i in imports if i not in new_imports]
if len(imports) > 0:
# If there are any objects left, they may be a submodule
path_to_module = PATH_TO_REPO / module.replace("__init__.py", "")
dependencies.extend(
[
os.path.join(module.replace("__init__.py", ""), f"{i}.py")
for i in imports
if (path_to_module / f"{i}.py").is_file()
]
)
imports = [i for i in imports if not (path_to_module / f"{i}.py").is_file()]
if len(imports) > 0:
# Then if there are still objects left, they are fully defined in the init, so we keep it as a
# dependency.
dependencies.append(module)
else:
dependencies.append(module)
imported_modules = new_modules
return dependencies
def get_test_dependencies(test_fname):
"""
Get the dependencies of a test file.
"""
with open(os.path.join(PATH_TO_TRANFORMERS, test_fname), "r", encoding="utf-8") as f:
content = f.read()
# Tests only have relative imports for other test files
# TODO Sylvain: handle relative imports cleanly
relative_imports = re.findall(r"from\s+(\.\S+)\s+import\s+([^\n]+)\n", content)
relative_imports = [test for test, imp in relative_imports if "# tests_ignore" not in imp]
def _convert_relative_import_to_file(relative_import):
level = 0
while relative_import.startswith("."):
level += 1
relative_import = relative_import[1:]
directory = os.path.sep.join(test_fname.split(os.path.sep)[:-level])
return os.path.join(directory, f"{relative_import.replace('.', os.path.sep)}.py")
dependencies = [_convert_relative_import_to_file(relative_import) for relative_import in relative_imports]
return [f for f in dependencies if os.path.isfile(os.path.join(PATH_TO_TRANFORMERS, f))]
def create_reverse_dependency_tree():
"""
Create a list of all edges (a, b) which means that modifying a impacts b with a going over all module and test files.
"""
modules = [
str(f.relative_to(PATH_TO_TRANFORMERS))
for f in (Path(PATH_TO_TRANFORMERS) / "src/transformers").glob("**/*.py")
]
module_edges = [(d, m) for m in modules for d in get_module_dependencies(m)]
cache = {}
all_modules = list(PATH_TO_TRANFORMERS.glob("**/*.py")) + list(PATH_TO_TESTS.glob("**/*.py"))
all_modules = [str(mod.relative_to(PATH_TO_REPO)) for mod in all_modules]
edges = [(dep, mod) for mod in all_modules for dep in get_module_dependencies(mod, cache=cache)]
tests = [str(f.relative_to(PATH_TO_TRANFORMERS)) for f in (Path(PATH_TO_TRANFORMERS) / "tests").glob("**/*.py")]
test_edges = [(d, t) for t in tests for d in get_test_dependencies(t)]
return module_edges + test_edges
return list(set(edges))
def get_tree_starting_at(module, edges):
@@ -264,13 +365,17 @@ def get_tree_starting_at(module, edges):
starting at module], [list of edges starting at the preceding level], ...]
"""
vertices_seen = [module]
new_edges = [edge for edge in edges if edge[0] == module and edge[1] != module]
new_edges = [edge for edge in edges if edge[0] == module and edge[1] != module and "__init__.py" not in edge[1]]
tree = [module]
while len(new_edges) > 0:
tree.append(new_edges)
final_vertices = list({edge[1] for edge in new_edges})
vertices_seen.extend(final_vertices)
new_edges = [edge for edge in edges if edge[0] in final_vertices and edge[1] not in vertices_seen]
new_edges = [
edge
for edge in edges
if edge[0] in final_vertices and edge[1] not in vertices_seen and "__init__.py" not in edge[1]
]
return tree
@@ -308,290 +413,159 @@ def create_reverse_dependency_map():
Create the dependency map from module/test filename to the list of modules/tests that depend on it (even
recursively).
"""
modules = [
str(f.relative_to(PATH_TO_TRANFORMERS))
for f in (Path(PATH_TO_TRANFORMERS) / "src/transformers").glob("**/*.py")
]
# We grab all the dependencies of each module.
direct_deps = {m: get_module_dependencies(m) for m in modules}
# We add all the dependencies of each test file
tests = [str(f.relative_to(PATH_TO_TRANFORMERS)) for f in (Path(PATH_TO_TRANFORMERS) / "tests").glob("**/*.py")]
direct_deps.update({t: get_test_dependencies(t) for t in tests})
all_files = modules + tests
cache = {}
all_modules = list(PATH_TO_TRANFORMERS.glob("**/*.py")) + list(PATH_TO_TESTS.glob("**/*.py"))
all_modules = [str(mod.relative_to(PATH_TO_REPO)) for mod in all_modules]
direct_deps = {m: get_module_dependencies(m, cache=cache) for m in all_modules}
# This recurses the dependencies
something_changed = True
while something_changed:
something_changed = False
for m in all_files:
for m in all_modules:
for d in direct_deps[m]:
if d.endswith("__init__.py"):
continue
if d not in direct_deps:
raise ValueError(f"KeyError:{d}. From {m}")
for dep in direct_deps[d]:
if dep not in direct_deps[m]:
direct_deps[m].append(dep)
something_changed = True
new_deps = set(direct_deps[d]) - set(direct_deps[m])
if len(new_deps) > 0:
direct_deps[m].extend(list(new_deps))
something_changed = True
# Finally we can build the reverse map.
reverse_map = collections.defaultdict(list)
for m in all_files:
if m.endswith("__init__.py"):
reverse_map[m].extend(direct_deps[m])
for m in all_modules:
for d in direct_deps[m]:
reverse_map[d].append(m)
for m in [f for f in all_modules if f.endswith("__init__.py")]:
direct_deps = get_module_dependencies(m, cache=cache)
deps = sum([reverse_map[d] for d in direct_deps if not d.endswith("__init__.py")], direct_deps)
reverse_map[m] = list(set(deps) - {m})
return reverse_map
# Any module file that has a test name which can't be inferred automatically from its name should go here. A better
# approach is to (re-)name the test file accordingly, and second best to add the correspondence map here.
SPECIAL_MODULE_TO_TEST_MAP = {
"commands/add_new_model_like.py": "utils/test_add_new_model_like.py",
"configuration_utils.py": "test_configuration_common.py",
"convert_graph_to_onnx.py": "onnx/test_onnx.py",
"data/data_collator.py": "trainer/test_data_collator.py",
"deepspeed.py": "deepspeed/",
"feature_extraction_sequence_utils.py": "test_sequence_feature_extraction_common.py",
"feature_extraction_utils.py": "test_feature_extraction_common.py",
"file_utils.py": ["utils/test_file_utils.py", "utils/test_model_output.py"],
"image_processing_utils.py": ["test_image_processing_common.py", "utils/test_image_processing_utils.py"],
"image_transforms.py": "test_image_transforms.py",
"utils/generic.py": ["utils/test_file_utils.py", "utils/test_model_output.py", "utils/test_generic.py"],
"utils/hub.py": "utils/test_hub_utils.py",
"modelcard.py": "utils/test_model_card.py",
"modeling_flax_utils.py": "test_modeling_flax_common.py",
"modeling_tf_utils.py": ["test_modeling_tf_common.py", "utils/test_modeling_tf_core.py"],
"modeling_utils.py": ["test_modeling_common.py", "utils/test_offline.py"],
"models/auto/modeling_auto.py": [
"models/auto/test_modeling_auto.py",
"models/auto/test_modeling_tf_pytorch.py",
"models/bort/test_modeling_bort.py",
"models/dit/test_modeling_dit.py",
],
"models/auto/modeling_flax_auto.py": "models/auto/test_modeling_flax_auto.py",
"models/auto/modeling_tf_auto.py": [
"models/auto/test_modeling_tf_auto.py",
"models/auto/test_modeling_tf_pytorch.py",
"models/bort/test_modeling_tf_bort.py",
],
"models/gpt2/modeling_gpt2.py": [
"models/gpt2/test_modeling_gpt2.py",
"models/megatron_gpt2/test_modeling_megatron_gpt2.py",
],
"models/dpt/modeling_dpt.py": [
"models/dpt/test_modeling_dpt.py",
"models/dpt/test_modeling_dpt_hybrid.py",
],
"optimization.py": "optimization/test_optimization.py",
"optimization_tf.py": "optimization/test_optimization_tf.py",
"pipelines/__init__.py": all_pipeline_test_files + all_model_test_files,
"pipelines/base.py": all_pipeline_test_files + all_model_test_files,
"pipelines/text2text_generation.py": [
"pipelines/test_pipelines_text2text_generation.py",
"pipelines/test_pipelines_summarization.py",
"pipelines/test_pipelines_translation.py",
],
"pipelines/zero_shot_classification.py": "pipelines/test_pipelines_zero_shot.py",
"testing_utils.py": "utils/test_skip_decorators.py",
"tokenization_utils.py": ["test_tokenization_common.py", "tokenization/test_tokenization_utils.py"],
"tokenization_utils_base.py": ["test_tokenization_common.py", "tokenization/test_tokenization_utils.py"],
"tokenization_utils_fast.py": [
"test_tokenization_common.py",
"tokenization/test_tokenization_utils.py",
"tokenization/test_tokenization_fast.py",
],
"trainer.py": [
"trainer/test_trainer.py",
"extended/test_trainer_ext.py",
"trainer/test_trainer_distributed.py",
"trainer/test_trainer_tpu.py",
],
"train_pt_utils.py": "trainer/test_trainer_utils.py",
"utils/versions.py": "utils/test_versions_utils.py",
}
def module_to_test_file(module_fname):
def create_module_to_test_map(reverse_map=None, filter_models=False):
"""
Returns the name of the file(s) where `module_fname` is tested.
Extract the tests from the reverse_dependency_map and potentially filters the model tests.
"""
splits = module_fname.split(os.path.sep)
if reverse_map is None:
reverse_map = create_reverse_dependency_map()
test_map = {module: [f for f in deps if f.startswith("tests")] for module, deps in reverse_map.items()}
# Special map has priority
short_name = os.path.sep.join(splits[2:])
if short_name in SPECIAL_MODULE_TO_TEST_MAP:
test_file = SPECIAL_MODULE_TO_TEST_MAP[short_name]
if isinstance(test_file, str):
return f"tests/{test_file}"
return [f"tests/{f}" for f in test_file]
if not filter_models:
return test_map
module_name = splits[-1]
# Fast tokenizers are tested in the same file as the slow ones.
if module_name.endswith("_fast.py"):
module_name = module_name.replace("_fast.py", ".py")
num_model_tests = len(list(PATH_TO_TESTS.glob("models/*")))
# Special case for pipelines submodules
if len(splits) >= 2 and splits[-2] == "pipelines":
default_test_file = f"tests/pipelines/test_pipelines_{module_name}"
return [default_test_file] + all_model_test_files
# Special case for benchmarks submodules
elif len(splits) >= 2 and splits[-2] == "benchmark":
return ["tests/benchmark/test_benchmark.py", "tests/benchmark/test_benchmark_tf.py"]
# Special case for commands submodules
elif len(splits) >= 2 and splits[-2] == "commands":
return "tests/utils/test_cli.py"
# Special case for onnx submodules
elif len(splits) >= 2 and splits[-2] == "onnx":
return ["tests/onnx/test_features.py", "tests/onnx/test_onnx.py", "tests/onnx/test_onnx_v2.py"]
# Special case for utils (not the one in src/transformers, the ones at the root of the repo).
elif len(splits) > 0 and splits[0] == "utils":
default_test_file = f"tests/repo_utils/test_{module_name}"
elif len(splits) > 4 and splits[2] == "models":
default_test_file = f"tests/models/{splits[3]}/test_{module_name}"
elif len(splits) > 2 and splits[2].startswith("generation"):
default_test_file = f"tests/generation/test_{module_name}"
elif len(splits) > 2 and splits[2].startswith("trainer"):
default_test_file = f"tests/trainer/test_{module_name}"
else:
default_test_file = f"tests/utils/test_{module_name}"
def has_many_models(tests):
model_tests = {Path(t).parts[2] for t in tests if t.startswith("tests/models/")}
return len(model_tests) > num_model_tests // 2
if os.path.isfile(default_test_file):
return default_test_file
def filter_tests(tests):
return [t for t in tests if not t.startswith("tests/models/") or Path(t).parts[2] in IMPORTANT_MODELS]
# Processing -> processor
if "processing" in default_test_file:
test_file = default_test_file.replace("processing", "processor")
if os.path.isfile(test_file):
return test_file
return {module: (filter_tests(tests) if has_many_models(tests) else tests) for module, tests in test_map.items()}
# This list contains the list of test files we expect never to be launched from a change in a module/util. Those are
# launched separately.
EXPECTED_TEST_FILES_NEVER_TOUCHED = [
"tests/generation/test_framework_agnostic.py", # Mixins inherited by actual test classes
"tests/mixed_int8/test_mixed_int8.py", # Mixed-int8 bitsandbytes test
"tests/pipelines/test_pipelines_common.py", # Actually checked by the pipeline based file
"tests/sagemaker/test_single_node_gpu.py", # SageMaker test
"tests/sagemaker/test_multi_node_model_parallel.py", # SageMaker test
"tests/sagemaker/test_multi_node_data_parallel.py", # SageMaker test
"tests/test_pipeline_mixin.py", # Contains no test of its own (only the common tester class)
"tests/utils/test_doc_samples.py", # Doc tests
]
def check_imports_all_exist():
    """
    Print a message for every dependency returned by `get_module_dependencies` that is not an existing file in the
    repo, going over all python files under `PATH_TO_TRANFORMERS` and `PATH_TO_TESTS`.

    Isn't used per se by the test fetcher but might be used later as a quality check. Putting this here for now so the
    code is not lost.
    """
    import_cache = {}
    python_files = [*PATH_TO_TRANFORMERS.glob("**/*.py"), *PATH_TO_TESTS.glob("**/*.py")]
    relative_names = [str(path.relative_to(PATH_TO_REPO)) for path in python_files]
    # Resolve the dependencies of every file first, so nothing is reported before all of them have been computed.
    deps_by_module = {name: get_module_dependencies(name, cache=import_cache) for name in relative_names}

    for module, deps in deps_by_module.items():
        for dep in deps:
            if (PATH_TO_REPO / dep).is_file():
                continue
            print(f"{module} has dependency on {dep} which does not exist.")
def _print_list(l):
return "\n".join([f"- {f}" for f in l])
def sanity_check():
"""
Checks that all test files can be touched by a modification in at least one module/utils. This test ensures that
newly-added test files are properly mapped to some module or utils, so they can be run by the CI.
"""
# Grab all module and utils
all_files = [
str(p.relative_to(PATH_TO_TRANFORMERS))
for p in (Path(PATH_TO_TRANFORMERS) / "src/transformers").glob("**/*.py")
]
all_files += [
str(p.relative_to(PATH_TO_TRANFORMERS)) for p in (Path(PATH_TO_TRANFORMERS) / "utils").glob("**/*.py")
]
def create_json_map(test_files_to_run, json_output_file):
if json_output_file is None:
return
# Compute all the test files we get from those.
test_files_found = []
for f in all_files:
test_f = module_to_test_file(f)
if test_f is not None:
if isinstance(test_f, str):
test_files_found.append(test_f)
else:
test_files_found.extend(test_f)
# Some of the test files might actually be subfolders so we grab the tests inside.
test_files = []
for test_f in test_files_found:
if os.path.isdir(os.path.join(PATH_TO_TRANFORMERS, test_f)):
test_files.extend(
[
str(p.relative_to(PATH_TO_TRANFORMERS))
for p in (Path(PATH_TO_TRANFORMERS) / test_f).glob("**/test*.py")
]
)
test_map = {}
for test_file in test_files_to_run:
# `test_file` is a path to a test folder/file, starting with `tests/`. For example,
# - `tests/models/bert/test_modeling_bert.py` or `tests/models/bert`
# - `tests/trainer/test_trainer.py` or `tests/trainer`
# - `tests/test_modeling_common.py`
names = test_file.split(os.path.sep)
if names[1] == "models":
# take the part like `models/bert` for modeling tests
key = os.path.sep.join(names[1:3])
elif len(names) > 2 or not test_file.endswith(".py"):
# test folders under `tests` or python files under them
# take the part like tokenization, `pipeline`, etc. for other test categories
key = os.path.sep.join(names[1:2])
else:
test_files.append(test_f)
# common test files directly under `tests/`
key = "common"
# Compare to existing test files
existing_test_files = [
str(p.relative_to(PATH_TO_TRANFORMERS)) for p in (Path(PATH_TO_TRANFORMERS) / "tests").glob("**/test*.py")
]
not_touched_test_files = [f for f in existing_test_files if f not in test_files]
if key not in test_map:
test_map[key] = []
test_map[key].append(test_file)
should_be_tested = set(not_touched_test_files) - set(EXPECTED_TEST_FILES_NEVER_TOUCHED)
if len(should_be_tested) > 0:
raise ValueError(
"The following test files are not currently associated with any module or utils files, which means they "
f"will never get run by the CI:\n{_print_list(should_be_tested)}\n. Make sure the names of these test "
"files match the name of the module or utils they are testing, or adapt the constant "
"`SPECIAL_MODULE_TO_TEST_MAP` in `utils/tests_fetcher.py` to add them. If your test file is triggered "
"separately and is not supposed to be run by the regular CI, add it to the "
"`EXPECTED_TEST_FILES_NEVER_TOUCHED` constant instead."
)
# sort the keys & values
keys = sorted(test_map.keys())
test_map = {k: " ".join(sorted(test_map[k])) for k in keys}
with open(json_output_file, "w", encoding="UTF-8") as fp:
json.dump(test_map, fp, ensure_ascii=False)
# NOTE(review): this span is a flattened diff of `infer_tests_to_run`: lines from before and after the change are
# interleaved, indentation is missing and a `@@` hunk header hides part of the body. The comments below only describe
# what each visible line does; which side of the diff a line belongs to is an assumption to confirm.
# First signature (presumably the pre-change one): no `filter_models` parameter.
def infer_tests_to_run(output_file, diff_with_last_commit=False, filters=None, json_output_file=None):
# Second signature: same parameters plus `filter_models` (default `True`), which is forwarded to
# `create_module_to_test_map` further down.
def infer_tests_to_run(
output_file, diff_with_last_commit=False, filters=None, filter_models=True, json_output_file=None
):
# Collect the modified Python files (forwarding `diff_with_last_commit`) and print them.
modified_files = get_modified_python_files(diff_with_last_commit=diff_with_last_commit)
print(f"\n### MODIFIED FILES ###\n{_print_list(modified_files)}")
# Create the map that will give us all impacted modules.
# The same call is bound to two names: `impacted_modules_map` is only read in the loop just below, while
# `reverse_map` is also passed to `create_module_to_test_map` later (looks like a rename; confirm).
impacted_modules_map = create_reverse_dependency_map()
reverse_map = create_reverse_dependency_map()
# Work on a copy so that extending `impacted_files` does not change `modified_files`.
impacted_files = modified_files.copy()
for f in modified_files:
if f in impacted_modules_map:
impacted_files.extend(impacted_modules_map[f])
if f in reverse_map:
impacted_files.extend(reverse_map[f])
# Remove duplicates
impacted_files = sorted(set(impacted_files))
print(f"\n### IMPACTED FILES ###\n{_print_list(impacted_files)}")
# Grab the corresponding test files:
# Two variants of the same guard: one looks for `setup.py` in `impacted_files`, the other in `modified_files`.
if "setup.py" in impacted_files:
if "setup.py" in modified_files:
# Select the `tests` entry as a whole (see the `"tests" in test_files_to_run` check further down) and flag the
# repo utils tests.
test_files_to_run = ["tests"]
repo_utils_launch = True
else:
# Grab the corresponding test files:
# First variant: walk `impacted_files` and map each entry to test files by path prefix.
test_files_to_run = []
for f in impacted_files:
# Modified test files are always added
if f.startswith("tests/"):
test_files_to_run.append(f)
# Example files are tested separately
elif f.startswith("examples/pytorch"):
test_files_to_run.append("examples/pytorch/test_pytorch_examples.py")
test_files_to_run.append("examples/pytorch/test_accelerate_examples.py")
elif f.startswith("examples/tensorflow"):
test_files_to_run.append("examples/tensorflow/test_tensorflow_examples.py")
elif f.startswith("examples/flax"):
test_files_to_run.append("examples/flax/test_flax_examples.py")
else:
# The result of `module_to_test_file` is handled as `None` (nothing to add), a single string, or an iterable of
# strings.
new_tests = module_to_test_file(f)
if new_tests is not None:
if isinstance(new_tests, str):
test_files_to_run.append(new_tests)
else:
test_files_to_run.extend(new_tests)
# Remove duplicates
# All modified tests need to be run.
# Second variant: keep the modified files that start with `tests` and whose last path component starts with `test`.
test_files_to_run = [
f for f in modified_files if f.startswith("tests") and f.split(os.path.sep)[-1].startswith("test")
]
# Then we grab the corresponding test files.
# `filter_models` is forwarded as-is; its effect is decided inside `create_module_to_test_map` (not visible here).
test_map = create_module_to_test_map(reverse_map=reverse_map, filter_models=filter_models)
for f in modified_files:
if f in test_map:
test_files_to_run.extend(test_map[f])
# Deduplicate and sort the collected test files.
test_files_to_run = sorted(set(test_files_to_run))
# Remove SageMaker tests
# i.e. drop every path whose second component is `sagemaker`.
test_files_to_run = [f for f in test_files_to_run if not f.split(os.path.sep)[1] == "sagemaker"]
# Make sure we did not end up with a test file that was removed
# Two variants of the existence filter: `os.path.isfile/isdir` on the raw path, and `.exists()` on the path joined
# to `PATH_TO_REPO`.
test_files_to_run = [f for f in test_files_to_run if os.path.isfile(f) or os.path.isdir(f)]
test_files_to_run = [f for f in test_files_to_run if (PATH_TO_REPO / f).exists()]
# When `filters` is given, keep only the test files that start with one of the filter prefixes. A file matching
# several prefixes is added once per matching prefix.
if filters is not None:
filtered_files = []
# Two variants of the same loop: the first names its variable `filter` (shadowing the builtin), the second uses
# `_filter`.
for filter in filters:
filtered_files.extend([f for f in test_files_to_run if f.startswith(filter)])
for _filter in filters:
filtered_files.extend([f for f in test_files_to_run if f.startswith(_filter)])
test_files_to_run = filtered_files
# Two variants of the repo utils flag: one looks for a `repo_utils` second path component in `test_files_to_run`,
# the other in `modified_files`.
repo_utils_launch = any(f.split(os.path.sep)[1] == "repo_utils" for f in test_files_to_run)
repo_utils_launch = any(f.split(os.path.sep)[1] == "repo_utils" for f in modified_files)
if repo_utils_launch:
# `test_repo_utils.txt` is placed next to `output_file`; what is done with it is hidden by the hunk header below.
repo_util_file = Path(output_file).parent / "test_repo_utils.txt"
@@ -610,34 +584,7 @@ def infer_tests_to_run(output_file, diff_with_last_commit=False, filters=None, j
# Replace the `tests` entry with the full list returned by `get_all_tests()`.
if "tests" in test_files_to_run:
test_files_to_run = get_all_tests()
# Inline construction of the JSON map: test files are grouped under a key (`models/<name>`, the first folder under
# `tests`, or `common`) and each group is written as a space-separated string. The last line of this span hands the
# same two values to `create_json_map` instead.
if json_output_file is not None:
test_map = {}
for test_file in test_files_to_run:
# `test_file` is a path to a test folder/file, starting with `tests/`. For example,
# - `tests/models/bert/test_modeling_bert.py` or `tests/models/bert`
# - `tests/trainer/test_trainer.py` or `tests/trainer`
# - `tests/test_modeling_common.py`
names = test_file.split(os.path.sep)
if names[1] == "models":
# take the part like `models/bert` for modeling tests
key = "/".join(names[1:3])
elif len(names) > 2 or not test_file.endswith(".py"):
# test folders under `tests` or python files under them
# take the part like tokenization, `pipeline`, etc. for other test categories
key = "/".join(names[1:2])
else:
# common test files directly under `tests/`
key = "common"
if key not in test_map:
test_map[key] = []
test_map[key].append(test_file)
# sort the keys & values
keys = sorted(test_map.keys())
test_map = {k: " ".join(sorted(test_map[k])) for k in keys}
with open(json_output_file, "w", encoding="UTF-8") as fp:
json.dump(test_map, fp, ensure_ascii=False)
create_json_map(test_files_to_run, json_output_file)
def filter_tests(output_file, filters):
@@ -667,11 +614,29 @@ def filter_tests(output_file, filters):
f.write(" ".join(test_files))
def parse_commit_message(commit_message):
    """
    Parses the commit message to detect if a command is there to skip, force all or part of the CI.

    Only the first `[...]` group found in the message is inspected. Its content is lower-cased, `-` and `_` are
    treated as spaces and runs of whitespace are collapsed, so `[skip-ci]`, `[Skip_CI]` and `[ skip  ci ]` are all
    read as the same command.

    Args:
        commit_message (`str`, *optional*):
            The commit message to parse. `None` is treated as a message without any command.

    Returns:
        `Dict[str, bool]`: A dictionary with the keys `"skip"`, `"no_filter"` and `"test_all"`.
    """
    flags = {"skip": False, "no_filter": False, "test_all": False}
    if commit_message is None:
        return flags

    command_search = re.search(r"\[([^\]]*)\]", commit_message)
    if command_search is None:
        return flags

    # Splitting on any whitespace (instead of a single literal space) keeps stray or doubled spaces inside the
    # brackets from hiding an otherwise valid command.
    words = command_search.group(1).lower().replace("-", " ").replace("_", " ").split()
    command = " ".join(words)

    flags["skip"] = command in ("ci skip", "skip ci", "circleci skip", "skip circleci")
    # Word order does not matter for these two commands, only the set of words used.
    flags["no_filter"] = set(words) == {"no", "filter"}
    flags["test_all"] = set(words) == {"test", "all"}
    return flags
# NOTE(review): this span is a flattened diff of the script entry point: lines from before and after the change are
# interleaved, indentation is missing and a `@@` hunk header hides some of the argument definitions. The comments
# below only describe what each visible line does; which side of the diff a line belongs to is an assumption to
# confirm.
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# `--sanity_check` pairs with the `elif args.sanity_check` branch below (presumably both on the removed side).
parser.add_argument(
"--sanity_check", action="store_true", help="Only test that all tests and modules are accounted for."
)
# File that receives the list of tests to run.
parser.add_argument(
"--output_file", type=str, default="test_list.txt", help="Where to store the list of tests to run"
)
@@ -704,33 +669,54 @@ if __name__ == "__main__":
help="Will only print the tree of modules depending on the file passed.",
default=None,
)
# NOTE(review): in the lines visible here `args.commit_message` is never read; the message that gets parsed below
# comes from `repo.head.commit.message`. Confirm whether this option is consumed elsewhere.
parser.add_argument(
"--commit_message",
type=str,
help="The commit message (which could contain a command to force all tests or skip the CI).",
default=None,
)
args = parser.parse_args()
# Dispatch on the requested mode; the final `else` is the regular test-fetching path.
if args.print_dependencies_of is not None:
print_tree_deps_of(args.print_dependencies_of)
elif args.sanity_check:
sanity_check()
elif args.filter_tests:
filter_tests(args.output_file, ["pipelines", "repo_utils"])
else:
# Two variants of opening the git repository: rooted at `PATH_TO_TRANFORMERS` or at `PATH_TO_REPO`.
repo = Repo(PATH_TO_TRANFORMERS)
repo = Repo(PATH_TO_REPO)
# Read the CI commands from the message of the commit currently checked out.
commit_message = repo.head.commit.message
commit_flags = parse_commit_message(commit_message)
if commit_flags["skip"]:
print("Force-skipping the CI")
# Stop here: nothing below runs, so no output file is written on this path.
quit()
if commit_flags["no_filter"]:
print("Running all tests fetched without filtering.")
if commit_flags["test_all"]:
print("Force-launching all tests")
diff_with_last_commit = args.diff_with_last_commit
# When HEAD is attached and points at the `main` ref, switch to diffing against the last commit.
if not diff_with_last_commit and not repo.head.is_detached and repo.head.ref == repo.refs.main:
print("main branch detected, fetching tests against last commit.")
diff_with_last_commit = True
# First variant of the fetch: always attempted, and a failure is only reported through the `print` below.
try:
infer_tests_to_run(
args.output_file,
diff_with_last_commit=diff_with_last_commit,
filters=args.filters,
json_output_file=args.json_output_file,
)
filter_tests(args.output_file, ["repo_utils"])
except Exception as e:
print(f"\nError when trying to grab the relevant tests: {e}\n\nRunning all tests.")
# Second variant: the fetch is skipped when all tests were requested, `no_filter` is forwarded as
# `filter_models=False`, and any failure flips `test_all` so that the fallback below runs.
if not commit_flags["test_all"]:
try:
infer_tests_to_run(
args.output_file,
diff_with_last_commit=diff_with_last_commit,
filters=args.filters,
json_output_file=args.json_output_file,
filter_models=not commit_flags["no_filter"],
)
filter_tests(args.output_file, ["repo_utils"])
except Exception as e:
print(f"\nError when trying to grab the relevant tests: {e}\n\nRunning all tests.")
commit_flags["test_all"] = True
# Run-everything path: write `./tests/` (or the space-joined filters when some were given) to the output file, then
# build the JSON map from the full list returned by `get_all_tests()`.
if commit_flags["test_all"]:
with open(args.output_file, "w", encoding="utf-8") as f:
if args.filters is None:
f.write("./tests/")
else:
f.write(" ".join(args.filters))
test_files_to_run = get_all_tests()
create_json_map(test_files_to_run, args.json_output_file)