diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 51246b628b..a455a31038 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -140,6 +140,8 @@ class ModelTesterMixin: return inputs_dict def test_save_load(self): + # Fake modif, will remove before merging. + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: diff --git a/utils/tests_fetcher.py b/utils/tests_fetcher.py index e518a29b22..3080af5be1 100644 --- a/utils/tests_fetcher.py +++ b/utils/tests_fetcher.py @@ -187,9 +187,23 @@ def get_module_dependencies(module_fname): return dependencies +def get_test_dependencies(test_fname): + """ + Get the dependencies of a test file. + """ + with open(os.path.join(PATH_TO_TRANFORMERS, test_fname), "r", encoding="utf-8") as f: + content = f.read() + + # Tests only have relative imports for other test files + relative_imports = re.findall(r"from\s+\.(\S+)\s+import\s+([^\n]+)\n", content) + relative_imports = [test for test, imp in relative_imports if "# tests_ignore" not in imp] + return [os.path.join("tests", f"{test}.py") for test in relative_imports] + + def create_reverse_dependency_map(): """ - Create the dependency map from module filename to the list of modules that depend on it (even recursively). + Create the dependency map from module/test filename to the list of modules/tests that depend on it (even + recursively). """ modules = [ str(f.relative_to(PATH_TO_TRANFORMERS)) @@ -198,11 +212,17 @@ def create_reverse_dependency_map(): # We grab all the dependencies of each module. direct_deps = {m: get_module_dependencies(m) for m in modules} + # We add all the dependencies of each test file + tests = [str(f.relative_to(PATH_TO_TRANFORMERS)) for f in (Path(PATH_TO_TRANFORMERS) / "tests").glob("**/*.py")] + direct_deps.update({t: get_test_dependencies(t) for t in tests}) + + all_files = modules + tests + # This recurses the dependencies something_changed = True while something_changed: something_changed = False - for m in modules: + for m in all_files: for d in direct_deps[m]: for dep in direct_deps[d]: if dep not in direct_deps[m]: @@ -211,7 +231,7 @@ def create_reverse_dependency_map(): # Finally we can build the reverse map. reverse_map = collections.defaultdict(list) - for m in modules: + for m in all_files: if m.endswith("__init__.py"): reverse_map[m].extend(direct_deps[m]) for d in direct_deps[m]: