Dummies multi backend (#11100)

* Replaces requires_xxx by one generic method

* Quality and update check_dummies

* Fix inits check

* Post-merge cleanup
This commit is contained in:
Sylvain Gugger
2021-04-07 09:56:40 -04:00
committed by GitHub
parent 424419f549
commit 11505fa139
18 changed files with 1246 additions and 1275 deletions

View File

@@ -22,11 +22,11 @@ import re
# python utils/check_dummies.py
# Location of the transformers package sources; a relative path (presumably resolved from the
# repository root, where the script is invoked -- TODO confirm).
PATH_TO_TRANSFORMERS = "src/transformers"
# Matches is_xxx_available and captures xxx.
# NOTE: the trailing `()` in the pattern is an empty capture group, not literal parentheses, so the
# pattern has two groups and `findall` returns (name, "") tuples rather than plain strings.
_re_backend = re.compile(r"is\_([a-z]*)_available()")
# Matches from xxx import bla
# (requires leading whitespace; captures everything after `import` up to the newline, and does not
# match when the imported names start with an opening parenthesis)
_re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
# Matches an indented line consisting solely of `if is_xxx_available():` and captures xxx.
_re_test_backend = re.compile(r"^\s+if\s+is\_([a-z]*)\_available\(\):\s*$")
# Backend names tracked for dummies.
BACKENDS = ["torch", "tf", "flax", "sentencepiece", "speech", "tokenizers", "vision"]
# Matches an indented line that starts with `if is_xxx_available()`. Only the prefix is checked and
# nothing is captured, so lines combining several availability checks match as well.
_re_test_backend = re.compile(r"^\s+if\s+is\_[a-z]*\_available\(\)")
DUMMY_CONSTANT = """
@@ -36,25 +36,34 @@ DUMMY_CONSTANT = """
DUMMY_PRETRAINED_CLASS = """
class {0}:
def __init__(self, *args, **kwargs):
requires_{1}(self)
requires_backends(self, {1})
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_{1}(self)
requires_backends(self, {1})
"""
DUMMY_CLASS = """
class {0}:
def __init__(self, *args, **kwargs):
requires_{1}(self)
requires_backends(self, {1})
"""
DUMMY_FUNCTION = """
def {0}(*args, **kwargs):
requires_{1}({0})
requires_backends({0}, {1})
"""
def find_backend(line):
    """Return the backend key for an `if is_xxx_available()` line of the init.

    Args:
        line: One line of source code from the init file.

    Returns:
        None when the line does not match `_re_test_backend`. Otherwise, the backend names
        found on the line by `_re_backend`, sorted and joined with "_and_" (a line testing a
        single backend yields just that name).
    """
    # Only lines that open an availability test are of interest.
    if _re_test_backend.search(line) is None:
        return None
    # Each findall hit is indexed at 0 to keep the captured backend name.
    names = sorted(hit[0] for hit in _re_backend.findall(line))
    return "_and_".join(names)
def read_init():
""" Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects. """
with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f:
@@ -69,14 +78,10 @@ def read_init():
# Go through the end of the file
while line_index < len(lines):
# If the line is an if is_backend_available, we grab all objects associated.
if _re_test_backend.search(lines[line_index]) is not None:
backend = _re_test_backend.search(lines[line_index]).groups()[0]
backend = find_backend(lines[line_index])
if backend is not None:
line_index += 1
# Ignore if backend isn't tracked for dummies.
if backend not in BACKENDS:
continue
objects = []
# Until we unindent, add backend objects to the list
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8):
@@ -128,13 +133,12 @@ def create_dummy_files():
""" Create the content of the dummy files. """
backend_specific_objects = read_init()
# For special correspondence backend to module name as used in the function requires_modulename
module_names = {"torch": "pytorch"}
dummy_files = {}
for backend, objects in backend_specific_objects.items():
backend_name = module_names.get(backend, backend)
backend_name = "[" + ", ".join(f'"{b}"' for b in backend.split("_and_")) + "]"
dummy_file = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
dummy_file += f"from ..file_utils import requires_{backend_name}\n\n"
dummy_file += "from ..file_utils import requires_backends\n\n"
dummy_file += "\n".join([create_dummy_object(o, backend_name) for o in objects])
dummy_files[backend] = dummy_file
@@ -156,8 +160,11 @@ def check_dummies(overwrite=False):
actual_dummies = {}
for backend, file_path in dummy_file_paths.items():
with open(file_path, "r", encoding="utf-8", newline="\n") as f:
actual_dummies[backend] = f.read()
if os.path.isfile(file_path):
with open(file_path, "r", encoding="utf-8", newline="\n") as f:
actual_dummies[backend] = f.read()
else:
actual_dummies[backend] = ""
for backend in dummy_files.keys():
if dummy_files[backend] != actual_dummies[backend]:

View File

@@ -18,12 +18,14 @@ import re
# Location of the transformers package sources; a relative path (presumably resolved from the
# repository root -- TODO confirm).
PATH_TO_TRANSFORMERS = "src/transformers"
# Backend names tracked for dummies.
BACKENDS = ["torch", "tf", "flax", "sentencepiece", "speech", "tokenizers", "vision"]
# Matches is_xxx_available and captures xxx.
# NOTE: the trailing `()` in the pattern is an empty capture group, not literal parentheses, so the
# pattern has two groups and `findall` returns (name, "") tuples rather than plain strings.
_re_backend = re.compile(r"is\_([a-z]*)_available()")
# Catches a line with a key-values pattern: "bla": ["foo", "bar"]
# (captures the text between the square brackets)
_re_import_struct_key_value = re.compile(r'\s+"\S*":\s+\[([^\]]*)\]')
# Catches a line if is_foo_available
# (this form requires the whole line to be a single `if is_foo_available():` and captures foo)
_re_test_backend = re.compile(r"^\s*if\s+is\_([a-z]*)\_available\(\):\s*$")
# Catches a line that starts with `if is_foo_available()`. Only the prefix is checked and nothing is
# captured, so lines combining several availability checks match as well.
_re_test_backend = re.compile(r"^\s*if\s+is\_[a-z]*\_available\(\)")
# Catches a line _import_struct["bla"].append("foo")
# (captures the appended name, without its quotes)
_re_import_struct_add_one = re.compile(r'^\s*_import_structure\["\S*"\]\.append\("(\S*)"\)')
# Catches a line _import_struct["bla"].extend(["foo", "bar"]) or _import_struct["bla"] = ["foo", "bar"]
@@ -36,6 +38,15 @@ _re_between_brackets = re.compile("^\s+\[([^\]]+)\]")
_re_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
def find_backend(line):
    """Identify which backend(s) an `if is_xxx_available()` line of the init tests for.

    Args:
        line: One line of source code from the init file.

    Returns:
        None if `_re_test_backend` does not match the line. Otherwise a string made of the
        backend names that `_re_backend` finds on the line, in sorted order and separated by
        "_and_".
    """
    if _re_test_backend.search(line) is None:
        # Not an availability test: no backend is associated with this line.
        return None
    # Element 0 of every findall result is the captured backend name.
    found = [result[0] for result in _re_backend.findall(line)]
    return "_and_".join(sorted(found))
def parse_init(init_file):
"""
Read an init_file and parse (per backend) the _import_structure objects defined and the TYPE_CHECKING objects
@@ -54,7 +65,7 @@ def parse_init(init_file):
# First grab the objects without a specific backend in _import_structure
objects = []
while not lines[line_index].startswith("if TYPE_CHECKING") and _re_test_backend.search(lines[line_index]) is None:
while not lines[line_index].startswith("if TYPE_CHECKING") and find_backend(lines[line_index]) is None:
line = lines[line_index]
single_line_import_search = _re_import_struct_key_value.search(line)
if single_line_import_search is not None:
@@ -68,14 +79,10 @@ def parse_init(init_file):
# Let's continue with backend-specific objects in _import_structure
while not lines[line_index].startswith("if TYPE_CHECKING"):
# If the line is an if is_backend_available, we grab all objects associated.
if _re_test_backend.search(lines[line_index]) is not None:
backend = _re_test_backend.search(lines[line_index]).groups()[0]
backend = find_backend(lines[line_index])
if backend is not None:
line_index += 1
# Ignore if backend isn't tracked for dummies.
if backend not in BACKENDS:
continue
objects = []
# Until we unindent, add backend objects to the list
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 4):
@@ -106,7 +113,7 @@ def parse_init(init_file):
objects = []
while (
line_index < len(lines)
and _re_test_backend.search(lines[line_index]) is None
and find_backend(lines[line_index]) is None
and not lines[line_index].startswith("else")
):
line = lines[line_index]
@@ -121,14 +128,10 @@ def parse_init(init_file):
# Let's continue with backend-specific objects
while line_index < len(lines):
# If the line is an if is_backend_available, we grab all objects associated.
if _re_test_backend.search(lines[line_index]) is not None:
backend = _re_test_backend.search(lines[line_index]).groups()[0]
backend = find_backend(lines[line_index])
if backend is not None:
line_index += 1
# Ignore if backend isn't tracked for dummies.
if backend not in BACKENDS:
continue
objects = []
# Until we unindent, add backend objects to the list
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8):