Dummies multi backend (#11100)
* Replaces requires_xxx by one generic method * Quality and update check_dummies * Fix inits check * Post-merge cleanup
This commit is contained in:
@@ -18,12 +18,14 @@ import re
|
||||
|
||||
|
||||
PATH_TO_TRANSFORMERS = "src/transformers"
|
||||
BACKENDS = ["torch", "tf", "flax", "sentencepiece", "speech", "tokenizers", "vision"]
|
||||
|
||||
|
||||
# Matches is_xxx_available()
|
||||
_re_backend = re.compile(r"is\_([a-z]*)_available()")
|
||||
# Catches a line with a key-values pattern: "bla": ["foo", "bar"]
|
||||
_re_import_struct_key_value = re.compile(r'\s+"\S*":\s+\[([^\]]*)\]')
|
||||
# Catches a line if is_foo_available
|
||||
_re_test_backend = re.compile(r"^\s*if\s+is\_([a-z]*)\_available\(\):\s*$")
|
||||
_re_test_backend = re.compile(r"^\s*if\s+is\_[a-z]*\_available\(\)")
|
||||
# Catches a line _import_struct["bla"].append("foo")
|
||||
_re_import_struct_add_one = re.compile(r'^\s*_import_structure\["\S*"\]\.append\("(\S*)"\)')
|
||||
# Catches a line _import_struct["bla"].extend(["foo", "bar"]) or _import_struct["bla"] = ["foo", "bar"]
|
||||
@@ -36,6 +38,15 @@ _re_between_brackets = re.compile("^\s+\[([^\]]+)\]")
|
||||
_re_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
|
||||
|
||||
|
||||
def find_backend(line):
|
||||
"""Find one (or multiple) backend in a code line of the init."""
|
||||
if _re_test_backend.search(line) is None:
|
||||
return None
|
||||
backends = [b[0] for b in _re_backend.findall(line)]
|
||||
backends.sort()
|
||||
return "_and_".join(backends)
|
||||
|
||||
|
||||
def parse_init(init_file):
|
||||
"""
|
||||
Read an init_file and parse (per backend) the _import_structure objects defined and the TYPE_CHECKING objects
|
||||
@@ -54,7 +65,7 @@ def parse_init(init_file):
|
||||
|
||||
# First grab the objects without a specific backend in _import_structure
|
||||
objects = []
|
||||
while not lines[line_index].startswith("if TYPE_CHECKING") and _re_test_backend.search(lines[line_index]) is None:
|
||||
while not lines[line_index].startswith("if TYPE_CHECKING") and find_backend(lines[line_index]) is None:
|
||||
line = lines[line_index]
|
||||
single_line_import_search = _re_import_struct_key_value.search(line)
|
||||
if single_line_import_search is not None:
|
||||
@@ -68,14 +79,10 @@ def parse_init(init_file):
|
||||
# Let's continue with backend-specific objects in _import_structure
|
||||
while not lines[line_index].startswith("if TYPE_CHECKING"):
|
||||
# If the line is an if is_backend_available, we grab all objects associated.
|
||||
if _re_test_backend.search(lines[line_index]) is not None:
|
||||
backend = _re_test_backend.search(lines[line_index]).groups()[0]
|
||||
backend = find_backend(lines[line_index])
|
||||
if backend is not None:
|
||||
line_index += 1
|
||||
|
||||
# Ignore if backend isn't tracked for dummies.
|
||||
if backend not in BACKENDS:
|
||||
continue
|
||||
|
||||
objects = []
|
||||
# Until we unindent, add backend objects to the list
|
||||
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 4):
|
||||
@@ -106,7 +113,7 @@ def parse_init(init_file):
|
||||
objects = []
|
||||
while (
|
||||
line_index < len(lines)
|
||||
and _re_test_backend.search(lines[line_index]) is None
|
||||
and find_backend(lines[line_index]) is None
|
||||
and not lines[line_index].startswith("else")
|
||||
):
|
||||
line = lines[line_index]
|
||||
@@ -121,14 +128,10 @@ def parse_init(init_file):
|
||||
# Let's continue with backend-specific objects
|
||||
while line_index < len(lines):
|
||||
# If the line is an if is_backemd_available, we grab all objects associated.
|
||||
if _re_test_backend.search(lines[line_index]) is not None:
|
||||
backend = _re_test_backend.search(lines[line_index]).groups()[0]
|
||||
backend = find_backend(lines[line_index])
|
||||
if backend is not None:
|
||||
line_index += 1
|
||||
|
||||
# Ignore if backend isn't tracked for dummies.
|
||||
if backend not in BACKENDS:
|
||||
continue
|
||||
|
||||
objects = []
|
||||
# Until we unindent, add backend objects to the list
|
||||
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8):
|
||||
|
||||
Reference in New Issue
Block a user