No more Tuple, List, Dict (#38797)

* No more Tuple, List, Dict

* make fixup

* More style fixes

* Docstring fixes with regex replacement

* Trigger tests

* Redo fixes after rebase

* Fix copies

* [test all]

* update

* [test all]

* update

* [test all]

* make style after rebase

* Patch the hf_argparser test

* Patch the hf_argparser test

* style fixes

* style fixes

* style fixes

* Fix docstrings in Cohere test

* [test all]

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Matt
2025-06-17 19:37:18 +01:00
committed by GitHub
parent a396f4324b
commit 508a704055
1291 changed files with 14906 additions and 14941 deletions

View File

@@ -42,7 +42,7 @@ import os
import re
import subprocess
from collections import OrderedDict
from typing import List, Optional, Tuple, Union
from typing import Optional, Union
from transformers.utils import direct_transformers_import
@@ -253,7 +253,7 @@ def _sanity_check_splits(splits_1, splits_2, is_class, filename):
raise ValueError(f"In {filename}, two code blocks expected to be copies have different structures.")
def find_block_end(lines: List[str], start_index: int, indent: int) -> int:
def find_block_end(lines: list[str], start_index: int, indent: int) -> int:
"""
Find the end of the class/func block starting at `start_index` in a source code (defined by `lines`).
@@ -282,8 +282,8 @@ def find_block_end(lines: List[str], start_index: int, indent: int) -> int:
def split_code_into_blocks(
lines: List[str], start_index: int, end_index: int, indent: int, backtrace: bool = False
) -> List[Tuple[str, int, int]]:
lines: list[str], start_index: int, end_index: int, indent: int, backtrace: bool = False
) -> list[tuple[str, int, int]]:
"""
Split the class/func block starting at `start_index` in a source code (defined by `lines`) into *inner blocks*.
@@ -391,7 +391,7 @@ def split_code_into_blocks(
def find_code_in_transformers(
object_name: str, base_path: Optional[str] = None, return_indices: bool = False
) -> Union[str, Tuple[List[str], int, int]]:
) -> Union[str, tuple[list[str], int, int]]:
"""
Find and return the source code of an object.
@@ -640,7 +640,7 @@ def check_codes_match(observed_code: str, theoretical_code: str) -> Optional[int
def is_copy_consistent(
filename: str, overwrite: bool = False, buffer: Optional[dict] = None
) -> Optional[List[Tuple[str, int]]]:
) -> Optional[list[tuple[str, int]]]:
"""
Check if the code commented as a copy in a file matches the original.
@@ -936,7 +936,7 @@ def get_model_list(filename: str, start_prompt: str, end_prompt: str) -> str:
return "".join(result)
def convert_to_localized_md(model_list: str, localized_model_list: str, format_str: str) -> Tuple[bool, str]:
def convert_to_localized_md(model_list: str, localized_model_list: str, format_str: str) -> tuple[bool, str]:
"""
Compare the model list from the main README to the one in a localized README.

View File

@@ -33,7 +33,6 @@ python utils/check_doc_toc.py --fix_and_overwrite
import argparse
from collections import defaultdict
from typing import List
import yaml
@@ -41,7 +40,7 @@ import yaml
PATH_TO_TOC = "docs/source/en/_toctree.yml"
def clean_model_doc_toc(model_doc: List[dict]) -> List[dict]:
def clean_model_doc_toc(model_doc: list[dict]) -> list[dict]:
"""
Cleans a section of the table of contents of the model documentation (one specific modality) by removing duplicates
and sorting models alphabetically.

View File

@@ -42,7 +42,7 @@ import operator as op
import os
import re
from pathlib import Path
from typing import Any, Optional, Tuple, Union
from typing import Any, Optional, Union
from check_repo import ignore_undocumented
from git import Repo
@@ -788,7 +788,7 @@ def find_source_file(obj: Any) -> Path:
return obj_file.with_suffix(".py")
def match_docstring_with_signature(obj: Any) -> Optional[Tuple[str, str]]:
def match_docstring_with_signature(obj: Any) -> Optional[tuple[str, str]]:
"""
Matches the docstring of an object with its signature.

View File

@@ -37,7 +37,7 @@ python utils/check_dummies.py --fix_and_overwrite
import argparse
import os
import re
from typing import Dict, List, Optional
from typing import Optional
# All paths are set with the intent you should run this script from the root of the repo with the command
@@ -92,7 +92,7 @@ def find_backend(line: str) -> Optional[str]:
return "_and_".join(backends)
def read_init() -> Dict[str, List[str]]:
def read_init() -> dict[str, list[str]]:
"""
Read the init and extract backend-specific objects.
@@ -156,7 +156,7 @@ def create_dummy_object(name: str, backend_name: str) -> str:
return DUMMY_CLASS.format(name, backend_name)
def create_dummy_files(backend_specific_objects: Optional[Dict[str, List[str]]] = None) -> Dict[str, str]:
def create_dummy_files(backend_specific_objects: Optional[dict[str, list[str]]] = None) -> dict[str, str]:
"""
Create the content of the dummy files.

View File

@@ -39,7 +39,7 @@ import collections
import os
import re
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from typing import Optional
# Path is set with the intent you should run this script from the root of the repo.
@@ -89,7 +89,7 @@ def find_backend(line: str) -> Optional[str]:
return "_and_".join(backends)
def parse_init(init_file) -> Optional[Tuple[Dict[str, List[str]], Dict[str, List[str]]]]:
def parse_init(init_file) -> Optional[tuple[dict[str, list[str]], dict[str, list[str]]]]:
"""
Read an init_file and parse (per backend) the `_import_structure` objects defined and the `TYPE_CHECKING` objects
defined.
@@ -232,7 +232,7 @@ def parse_init(init_file) -> Optional[Tuple[Dict[str, List[str]], Dict[str, List
return import_dict_objects, type_hint_objects
def analyze_results(import_dict_objects: Dict[str, List[str]], type_hint_objects: Dict[str, List[str]]) -> List[str]:
def analyze_results(import_dict_objects: dict[str, list[str]], type_hint_objects: dict[str, list[str]]) -> list[str]:
"""
Analyze the differences between _import_structure objects and TYPE_CHECKING objects found in an init.
@@ -279,7 +279,7 @@ def analyze_results(import_dict_objects: Dict[str, List[str]], type_hint_objects
return errors
def get_transformers_submodules() -> List[str]:
def get_transformers_submodules() -> list[str]:
"""
Returns the list of Transformers submodules.
"""

View File

@@ -40,7 +40,6 @@ from collections import OrderedDict
from difflib import get_close_matches
from importlib.machinery import ModuleSpec
from pathlib import Path
from typing import List, Tuple
from transformers import is_flax_available, is_tf_available, is_torch_available
from transformers.models.auto.auto_factory import get_values
@@ -470,7 +469,7 @@ def check_model_list():
# If some modeling modules should be ignored for all checks, they should be added in the nested list
# _ignore_modules of this function.
def get_model_modules() -> List[str]:
def get_model_modules() -> list[str]:
"""Get all the model modules inside the transformers library (except deprecated models)."""
_ignore_modules = [
"modeling_auto",
@@ -502,7 +501,7 @@ def get_model_modules() -> List[str]:
return modules
def get_models(module: types.ModuleType, include_pretrained: bool = False) -> List[Tuple[str, type]]:
def get_models(module: types.ModuleType, include_pretrained: bool = False) -> list[tuple[str, type]]:
"""
Get the objects in a module that are models.
@@ -564,7 +563,7 @@ def check_models_are_in_init():
# If some test_modeling files should be ignored when checking models are all tested, they should be added in the
# nested list _ignore_files of this function.
def get_model_test_files() -> List[str]:
def get_model_test_files() -> list[str]:
"""
Get the model test files.
@@ -605,7 +604,7 @@ def get_model_test_files() -> List[str]:
# This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the tester class
# for the all_model_classes variable.
def find_tested_models(test_file: str) -> List[str]:
def find_tested_models(test_file: str) -> list[str]:
"""
Parse the content of test_file to detect what's in `all_model_classes`. This detects the models that inherit from
the common test class.
@@ -640,7 +639,7 @@ def should_be_tested(model_name: str) -> bool:
return not is_building_block(model_name)
def check_models_are_tested(module: types.ModuleType, test_file: str) -> List[str]:
def check_models_are_tested(module: types.ModuleType, test_file: str) -> list[str]:
"""Check models defined in a module are all tested in a given file.
Args:
@@ -696,7 +695,7 @@ def check_all_models_are_tested():
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
def get_all_auto_configured_models() -> List[str]:
def get_all_auto_configured_models() -> list[str]:
"""Return the list of all models in at least one auto class."""
result = set() # To avoid duplicates we concatenate all model classes in a set.
if is_torch_available():
@@ -725,7 +724,7 @@ def ignore_unautoclassed(model_name: str) -> bool:
return False
def check_models_are_auto_configured(module: types.ModuleType, all_auto_models: List[str]) -> List[str]:
def check_models_are_auto_configured(module: types.ModuleType, all_auto_models: list[str]) -> list[str]:
"""
Check models defined in module are each in an auto class.
@@ -916,7 +915,7 @@ def check_objects_being_equally_in_main_init():
_re_decorator = re.compile(r"^\s*@(\S+)\s+$")
def check_decorator_order(filename: str) -> List[int]:
def check_decorator_order(filename: str) -> list[int]:
"""
Check that in a given test file, the slow decorator is always last.
@@ -958,7 +957,7 @@ def check_all_decorator_order():
)
def find_all_documented_objects() -> List[str]:
def find_all_documented_objects() -> list[str]:
"""
Parse the content of all doc files to detect which classes and functions it documents.

View File

@@ -38,7 +38,7 @@ python utils/custom_init_isort.py --check_only
import argparse
import os
import re
from typing import Any, Callable, List, Optional
from typing import Any, Callable, Optional
# Path is defined with the intent you should run this script from the root of the repo.
@@ -64,7 +64,7 @@ def get_indent(line: str) -> str:
def split_code_in_indented_blocks(
code: str, indent_level: str = "", start_prompt: Optional[str] = None, end_prompt: Optional[str] = None
) -> List[str]:
) -> list[str]:
"""
Split some code into its indented blocks, starting at a given level.
@@ -140,7 +140,7 @@ def ignore_underscore_and_lowercase(key: Callable[[Any], str]) -> Callable[[Any]
return _inner
def sort_objects(objects: List[Any], key: Optional[Callable[[Any], str]] = None) -> List[Any]:
def sort_objects(objects: list[Any], key: Optional[Callable[[Any], str]] = None) -> list[Any]:
"""
Sort a list of objects following the rules of isort (all uppercased first, camel-cased second and lower-cased
last).

View File

@@ -9,7 +9,7 @@ import argparse
import os
from collections import defaultdict
from pathlib import Path
from typing import Optional, Tuple
from typing import Optional
import requests
from custom_init_isort import sort_imports_in_all_inits
@@ -77,7 +77,7 @@ def insert_tip_to_model_doc(model_doc_path, tip_message):
f.write("\n".join(new_model_lines))
def get_model_doc_path(model: str) -> Tuple[Optional[str], Optional[str]]:
def get_model_doc_path(model: str) -> tuple[Optional[str], Optional[str]]:
# Possible variants of the model name in the model doc path
model_names = [model, model.replace("_", "-"), model.replace("_", "")]

View File

@@ -19,7 +19,7 @@ import os
import re
from abc import ABC, abstractmethod
from collections import Counter, defaultdict, deque
from typing import Dict, Optional, Set, Union
from typing import Optional, Union
import libcst as cst
from check_copies import run_ruff
@@ -197,7 +197,7 @@ def get_full_attribute_name(node: Union[cst.Attribute, cst.Name]) -> Optional[st
# Transformer class to replace ClassB.call_to_method and ClassB().call_to_method with super().call_to_method
class ReplaceMethodCallTransformer(cst.CSTTransformer):
def __init__(self, all_bases: Set[str]):
def __init__(self, all_bases: set[str]):
self.all_bases = all_bases
def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attribute) -> cst.CSTNode:
@@ -473,7 +473,7 @@ class SuperTransformer(cst.CSTTransformer):
def find_all_dependencies(
dependency_mapping: Dict[str, set],
dependency_mapping: dict[str, set],
start_entity: Optional[str] = None,
initial_dependencies: Optional[set] = None,
initial_checked_dependencies: Optional[set] = None,
@@ -636,11 +636,11 @@ class ModuleMapper(CSTVisitor, ABC):
def __init__(self, python_module: cst.Module):
# fmt: off
self.python_module: cst.Module = python_module # original cst.Module being visited
self.classes: Dict[str, cst.ClassDef] = {} # mapping from class names to Nodes (it will be ordered by default!!)
self.classes: dict[str, cst.ClassDef] = {} # mapping from class names to Nodes (it will be ordered by default!!)
self.imports = [] # stores all import statements
self.functions: Dict[str, cst.FunctionDef] = {} # mapping of global scope function names to Nodes
self.functions: dict[str, cst.FunctionDef] = {} # mapping of global scope function names to Nodes
self.object_dependency_mapping = defaultdict(set) # immediate function/assignment dependency mapping (i.e. dependencies immediately in the function/assignment definition)
self.assignments: Dict[str, cst.SimpleStatementLine] = {} # mapping of global assignments names to Nodes
self.assignments: dict[str, cst.SimpleStatementLine] = {} # mapping of global assignments names to Nodes
self.current_function = None # this keeps track of the current module-scope function
self.current_class = None # this keeps track of the current module-scope class
self.current_assignment = None # this keeps track of the current module-scope assignment
@@ -1266,8 +1266,8 @@ class ModularFileMapper(ModuleMapper):
# fmt: off
self.model_name = new_name # name of the model being defined. Should be in the format of `llama` or `layout_xlm` or `phi3`
self.model_specific_imported_objects: Dict[str, str] = {} # e.g. {"LlamaModel": "transformers.models.llama.modeling_llama"}
self.model_specific_modules: Dict[str, cst.Module] = {} # e.g. {"transformers.models.llama.modeling_llama": cst.Module}
self.model_specific_imported_objects: dict[str, str] = {} # e.g. {"LlamaModel": "transformers.models.llama.modeling_llama"}
self.model_specific_modules: dict[str, cst.Module] = {} # e.g. {"transformers.models.llama.modeling_llama": cst.Module}
self.all_all_to_add = {}
# fmt: on

View File

@@ -21,7 +21,7 @@ import os
import re
import sys
import time
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union
import requests
from get_ci_error_statistics import get_jobs
@@ -109,7 +109,7 @@ def handle_stacktraces(test_results):
return stacktraces
def dicts_to_sum(objects: Union[Dict[str, Dict], List[dict]]):
def dicts_to_sum(objects: Union[dict[str, dict], list[dict]]):
if isinstance(objects, dict):
lists = objects.values()
else:
@@ -126,9 +126,9 @@ class Message:
self,
title: str,
ci_title: str,
model_results: Dict,
additional_results: Dict,
selected_warnings: Optional[List] = None,
model_results: dict,
additional_results: dict,
selected_warnings: Optional[list] = None,
prev_ci_artifacts=None,
other_ci_artifacts=None,
):
@@ -203,15 +203,15 @@ class Message:
return f"{int(hours)}h{int(minutes)}m{int(seconds)}s"
@property
def header(self) -> Dict:
def header(self) -> dict:
return {"type": "header", "text": {"type": "plain_text", "text": self.title}}
@property
def ci_title_section(self) -> Dict:
def ci_title_section(self) -> dict:
return {"type": "section", "text": {"type": "mrkdwn", "text": self.ci_title}}
@property
def no_failures(self) -> Dict:
def no_failures(self) -> dict:
return {
"type": "section",
"text": {
@@ -227,7 +227,7 @@ class Message:
}
@property
def failures(self) -> Dict:
def failures(self) -> dict:
return {
"type": "section",
"text": {
@@ -246,7 +246,7 @@ class Message:
}
@property
def warnings(self) -> Dict:
def warnings(self) -> dict:
# If something goes wrong, let's avoid the CI report failing to be sent.
button_text = "Check warnings (Link not found)"
# Use the workflow run link
@@ -287,7 +287,7 @@ class Message:
return f"{'0'.rjust(rjust)} | {str(report['multi']).rjust(rjust)} | "
@property
def category_failures(self) -> Dict:
def category_failures(self) -> dict:
if job_name != "run_models_gpu":
category_failures_report = ""
return {"type": "section", "text": {"type": "mrkdwn", "text": category_failures_report}}
@@ -360,7 +360,7 @@ class Message:
return entries_changed
@property
def model_failures(self) -> List[Dict]:
def model_failures(self) -> list[dict]:
# Obtain per-model failures
def per_model_sum(model_category_dict):
return dicts_to_sum(model_category_dict["failed"].values())
@@ -523,7 +523,7 @@ class Message:
return model_failure_sections
@property
def additional_failures(self) -> Dict:
def additional_failures(self) -> dict:
failures = {k: v["failed"] for k, v in self.additional_results.items()}
errors = {k: v["error"] for k, v in self.additional_results.items()}
@@ -953,7 +953,7 @@ def retrieve_available_artifacts():
def add_path(self, path: str, gpu: Optional[str] = None):
self.paths.append({"name": self.name, "path": path, "gpu": gpu})
_available_artifacts: Dict[str, Artifact] = {}
_available_artifacts: dict[str, Artifact] = {}
directories = filter(os.path.isdir, os.listdir())
for directory in directories:

View File

@@ -16,7 +16,6 @@ import json
import os
import re
import time
from typing import Dict, List
from get_ci_error_statistics import get_jobs
from slack_sdk import WebClient
@@ -60,7 +59,7 @@ def extract_first_line_failure(failures_short_lines):
class Message:
def __init__(self, title: str, doc_test_results: Dict):
def __init__(self, title: str, doc_test_results: dict):
self.title = title
self.n_success = sum(job_result["n_success"] for job_result in doc_test_results.values())
@@ -90,11 +89,11 @@ class Message:
return f"{int(hours)}h{int(minutes)}m{int(seconds)}s"
@property
def header(self) -> Dict:
def header(self) -> dict:
return {"type": "header", "text": {"type": "plain_text", "text": self.title}}
@property
def no_failures(self) -> Dict:
def no_failures(self) -> dict:
return {
"type": "section",
"text": {
@@ -110,7 +109,7 @@ class Message:
}
@property
def failures(self) -> Dict:
def failures(self) -> dict:
return {
"type": "section",
"text": {
@@ -129,7 +128,7 @@ class Message:
}
@property
def category_failures(self) -> List[Dict]:
def category_failures(self) -> list[dict]:
failure_blocks = []
MAX_ERROR_TEXT = 3000 - len("The following examples had failures:\n\n\n\n") - len("[Truncated]\n")
@@ -301,7 +300,7 @@ def retrieve_available_artifacts():
def add_path(self, path: str):
self.paths.append({"name": self.name, "path": path})
_available_artifacts: Dict[str, Artifact] = {}
_available_artifacts: dict[str, Artifact] = {}
directories = filter(os.path.isdir, os.listdir())
for directory in directories:

View File

@@ -31,7 +31,6 @@ import os.path
import re
import string
from pathlib import Path
from typing import List
from git import Repo
@@ -39,7 +38,7 @@ from git import Repo
PATH_TO_REPO = Path(__file__).parent.parent.resolve()
def get_new_python_files_between_commits(base_commit: str, commits: List[str]) -> List[str]:
def get_new_python_files_between_commits(base_commit: str, commits: list[str]) -> list[str]:
"""
Get the list of added python files between a base commit and one or several commits.
@@ -64,7 +63,7 @@ def get_new_python_files_between_commits(base_commit: str, commits: List[str]) -
return code_diff
def get_new_python_files(diff_with_last_commit=False) -> List[str]:
def get_new_python_files(diff_with_last_commit=False) -> list[str]:
"""
Return a list of python files that have been added between the current head and the main branch.

View File

@@ -59,7 +59,7 @@ import re
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from typing import Optional, Union
from git import Repo
@@ -182,7 +182,7 @@ def keep_doc_examples_only(content: str) -> str:
return "\n".join(lines_to_keep)
def get_all_tests() -> List[str]:
def get_all_tests() -> list[str]:
"""
Walks the `tests` folder to return a list of files/subfolders. This is used to split the tests to run when using
parallelism. The split is:
@@ -263,7 +263,7 @@ def diff_contains_doc_examples(repo: Repo, branching_point: str, filename: str)
return old_content_clean != new_content_clean
def get_impacted_files_from_tiny_model_summary(diff_with_last_commit: bool = False) -> List[str]:
def get_impacted_files_from_tiny_model_summary(diff_with_last_commit: bool = False) -> list[str]:
"""
Return a list of python modeling files that are impacted by the changes of `tiny_model_summary.json` in between:
@@ -379,7 +379,7 @@ def get_impacted_files_from_tiny_model_summary(diff_with_last_commit: bool = Fal
return sorted(files)
def get_diff(repo: Repo, base_commit: str, commits: List[str]) -> List[str]:
def get_diff(repo: Repo, base_commit: str, commits: list[str]) -> list[str]:
"""
Get the diff between a base commit and one or several commits.
@@ -421,7 +421,7 @@ def get_diff(repo: Repo, base_commit: str, commits: List[str]) -> List[str]:
return code_diff
def get_modified_python_files(diff_with_last_commit: bool = False) -> List[str]:
def get_modified_python_files(diff_with_last_commit: bool = False) -> list[str]:
"""
Return a list of python files that have been modified between:
@@ -451,7 +451,7 @@ def get_modified_python_files(diff_with_last_commit: bool = False) -> List[str]:
return get_diff(repo, repo.head.commit, parent_commits)
def get_diff_for_doctesting(repo: Repo, base_commit: str, commits: List[str]) -> List[str]:
def get_diff_for_doctesting(repo: Repo, base_commit: str, commits: list[str]) -> list[str]:
"""
Get the diff in doc examples between a base commit and one or several commits.
@@ -492,7 +492,7 @@ def get_diff_for_doctesting(repo: Repo, base_commit: str, commits: List[str]) ->
return code_diff
def get_all_doctest_files() -> List[str]:
def get_all_doctest_files() -> list[str]:
"""
Return the complete list of python and Markdown files on which we run doctest.
@@ -525,7 +525,7 @@ def get_all_doctest_files() -> List[str]:
return sorted(test_files_to_run)
def get_new_doctest_files(repo, base_commit, branching_commit) -> List[str]:
def get_new_doctest_files(repo, base_commit, branching_commit) -> list[str]:
"""
Get the list of files that were removed from "utils/not_doctested.txt", between `base_commit` and
`branching_commit`.
@@ -552,7 +552,7 @@ def get_new_doctest_files(repo, base_commit, branching_commit) -> List[str]:
return []
def get_doctest_files(diff_with_last_commit: bool = False) -> List[str]:
def get_doctest_files(diff_with_last_commit: bool = False) -> list[str]:
"""
Return a list of python and Markdown files where doc examples have been modified between:
@@ -621,7 +621,7 @@ _re_single_line_direct_imports = re.compile(r"(?:^|\n)\s*from\s+transformers(\S*
_re_multi_line_direct_imports = re.compile(r"(?:^|\n)\s*from\s+transformers(\S*)\s+import\s+\(([^\)]+)\)")
def extract_imports(module_fname: str, cache: Optional[Dict[str, List[str]]] = None) -> List[str]:
def extract_imports(module_fname: str, cache: Optional[dict[str, list[str]]] = None) -> list[str]:
"""
Get the imports a given module makes.
@@ -703,7 +703,7 @@ def extract_imports(module_fname: str, cache: Optional[Dict[str, List[str]]] = N
return result
def get_module_dependencies(module_fname: str, cache: Optional[Dict[str, List[str]]] = None) -> List[str]:
def get_module_dependencies(module_fname: str, cache: Optional[dict[str, list[str]]] = None) -> list[str]:
"""
Refines the result of `extract_imports` to remove subfolders and get a proper list of module filenames: if a file
has an import `from utils import Foo, Bar`, with `utils` being a subfolder containing many files, this will traverse
@@ -786,7 +786,7 @@ def get_module_dependencies(module_fname: str, cache: Optional[Dict[str, List[st
return dependencies
def create_reverse_dependency_tree() -> List[Tuple[str, str]]:
def create_reverse_dependency_tree() -> list[tuple[str, str]]:
"""
Create a list of all edges (a, b), each meaning that modifying a impacts b, with a going over all module and test files.
"""
@@ -800,7 +800,7 @@ def create_reverse_dependency_tree() -> List[Tuple[str, str]]:
return list(set(edges))
def get_tree_starting_at(module: str, edges: List[Tuple[str, str]]) -> List[Union[str, List[str]]]:
def get_tree_starting_at(module: str, edges: list[tuple[str, str]]) -> list[Union[str, list[str]]]:
"""
Returns the tree starting at a given module following all edges.
@@ -861,7 +861,7 @@ def print_tree_deps_of(module, all_edges=None):
print(line[0])
def init_test_examples_dependencies() -> Tuple[Dict[str, List[str]], List[str]]:
def init_test_examples_dependencies() -> tuple[dict[str, list[str]], list[str]]:
"""
The test examples do not import from the examples (which are just scripts, not modules) so we need some extra
care initializing the dependency map, which is the goal of this function. It initializes the dependency map for
@@ -897,7 +897,7 @@ def init_test_examples_dependencies() -> Tuple[Dict[str, List[str]], List[str]]:
return test_example_deps, all_examples
def create_reverse_dependency_map() -> Dict[str, List[str]]:
def create_reverse_dependency_map() -> dict[str, list[str]]:
"""
Create the dependency map from module/test filename to the list of modules/tests that depend on it recursively.
@@ -953,8 +953,8 @@ def create_reverse_dependency_map() -> Dict[str, List[str]]:
def create_module_to_test_map(
reverse_map: Optional[Dict[str, List[str]]] = None, filter_models: bool = False
) -> Dict[str, List[str]]:
reverse_map: Optional[dict[str, list[str]]] = None, filter_models: bool = False
) -> dict[str, list[str]]:
"""
Extract the tests from the reverse_dependency_map and potentially filters the model tests.
@@ -1108,7 +1108,7 @@ def infer_tests_to_run(
f.write(" ".join(doctest_list))
def filter_tests(output_file: str, filters: List[str]):
def filter_tests(output_file: str, filters: list[str]):
"""
Reads the content of the output file and filters out all the tests in a list of given folders.
@@ -1135,7 +1135,7 @@ def filter_tests(output_file: str, filters: List[str]):
f.write(" ".join(test_files))
def parse_commit_message(commit_message: str) -> Dict[str, bool]:
def parse_commit_message(commit_message: str) -> dict[str, bool]:
"""
Parses the commit message to detect if a command is there to skip, force all or part of the CI.

View File

@@ -34,7 +34,6 @@ import collections
import os
import re
import tempfile
from typing import Dict, List, Tuple
import pandas as pd
from datasets import Dataset
@@ -124,7 +123,7 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [
]
def camel_case_split(identifier: str) -> List[str]:
def camel_case_split(identifier: str) -> list[str]:
"""
Split a camel-cased name into words.
@@ -213,7 +212,7 @@ def get_frameworks_table() -> pd.DataFrame:
return pd.DataFrame(data)
def update_pipeline_and_auto_class_table(table: Dict[str, Tuple[str, str]]) -> Dict[str, Tuple[str, str]]:
def update_pipeline_and_auto_class_table(table: dict[str, tuple[str, str]]) -> dict[str, tuple[str, str]]:
"""
Update the table mapping models to pipelines and auto classes without removing old keys if they don't exist anymore.