Dummies multi backend (#11100)

* Replaces requires_xxx by one generic method

* Quality and update check_dummies

* Fix inits check

* Post-merge cleanup
This commit is contained in:
Sylvain Gugger
2021-04-07 09:56:40 -04:00
committed by GitHub
parent 424419f549
commit 11505fa139
18 changed files with 1246 additions and 1275 deletions

View File

@@ -18,12 +18,14 @@ import re
PATH_TO_TRANSFORMERS = "src/transformers"
BACKENDS = ["torch", "tf", "flax", "sentencepiece", "speech", "tokenizers", "vision"]
# Matches is_xxx_available()
_re_backend = re.compile(r"is\_([a-z]*)_available()")
# Catches a line with a key-values pattern: "bla": ["foo", "bar"]
_re_import_struct_key_value = re.compile(r'\s+"\S*":\s+\[([^\]]*)\]')
# Catches a line if is_foo_available
_re_test_backend = re.compile(r"^\s*if\s+is\_([a-z]*)\_available\(\):\s*$")
_re_test_backend = re.compile(r"^\s*if\s+is\_[a-z]*\_available\(\)")
# Catches a line _import_struct["bla"].append("foo")
_re_import_struct_add_one = re.compile(r'^\s*_import_structure\["\S*"\]\.append\("(\S*)"\)')
# Catches a line _import_struct["bla"].extend(["foo", "bar"]) or _import_struct["bla"] = ["foo", "bar"]
@@ -36,6 +38,15 @@ _re_between_brackets = re.compile("^\s+\[([^\]]+)\]")
_re_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
def find_backend(line):
"""Find one (or multiple) backend in a code line of the init."""
if _re_test_backend.search(line) is None:
return None
backends = [b[0] for b in _re_backend.findall(line)]
backends.sort()
return "_and_".join(backends)
def parse_init(init_file):
"""
Read an init_file and parse (per backend) the _import_structure objects defined and the TYPE_CHECKING objects
@@ -54,7 +65,7 @@ def parse_init(init_file):
# First grab the objects without a specific backend in _import_structure
objects = []
while not lines[line_index].startswith("if TYPE_CHECKING") and _re_test_backend.search(lines[line_index]) is None:
while not lines[line_index].startswith("if TYPE_CHECKING") and find_backend(lines[line_index]) is None:
line = lines[line_index]
single_line_import_search = _re_import_struct_key_value.search(line)
if single_line_import_search is not None:
@@ -68,14 +79,10 @@ def parse_init(init_file):
# Let's continue with backend-specific objects in _import_structure
while not lines[line_index].startswith("if TYPE_CHECKING"):
# If the line is an if is_backend_available, we grab all objects associated.
if _re_test_backend.search(lines[line_index]) is not None:
backend = _re_test_backend.search(lines[line_index]).groups()[0]
backend = find_backend(lines[line_index])
if backend is not None:
line_index += 1
# Ignore if backend isn't tracked for dummies.
if backend not in BACKENDS:
continue
objects = []
# Until we unindent, add backend objects to the list
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 4):
@@ -106,7 +113,7 @@ def parse_init(init_file):
objects = []
while (
line_index < len(lines)
and _re_test_backend.search(lines[line_index]) is None
and find_backend(lines[line_index]) is None
and not lines[line_index].startswith("else")
):
line = lines[line_index]
@@ -121,14 +128,10 @@ def parse_init(init_file):
# Let's continue with backend-specific objects
while line_index < len(lines):
# If the line is an if is_backemd_available, we grab all objects associated.
if _re_test_backend.search(lines[line_index]) is not None:
backend = _re_test_backend.search(lines[line_index]).groups()[0]
backend = find_backend(lines[line_index])
if backend is not None:
line_index += 1
# Ignore if backend isn't tracked for dummies.
if backend not in BACKENDS:
continue
objects = []
# Until we unindent, add backend objects to the list
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8):