Dummies multi backend (#11100)
* Replaces requires_xxx by one generic method * Quality and update check_dummies * Fix inits check * Post-merge cleanup
This commit is contained in:
@@ -22,11 +22,11 @@ import re
|
||||
# python utils/check_dummies.py
|
||||
PATH_TO_TRANSFORMERS = "src/transformers"
|
||||
|
||||
# Matches is_xxx_available()
|
||||
_re_backend = re.compile(r"is\_([a-z]*)_available()")
|
||||
# Matches from xxx import bla
|
||||
_re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
|
||||
_re_test_backend = re.compile(r"^\s+if\s+is\_([a-z]*)\_available\(\):\s*$")
|
||||
|
||||
|
||||
BACKENDS = ["torch", "tf", "flax", "sentencepiece", "speech", "tokenizers", "vision"]
|
||||
_re_test_backend = re.compile(r"^\s+if\s+is\_[a-z]*\_available\(\)")
|
||||
|
||||
|
||||
DUMMY_CONSTANT = """
|
||||
@@ -36,25 +36,34 @@ DUMMY_CONSTANT = """
|
||||
DUMMY_PRETRAINED_CLASS = """
|
||||
class {0}:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_{1}(self)
|
||||
requires_backends(self, {1})
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_{1}(self)
|
||||
requires_backends(self, {1})
|
||||
"""
|
||||
|
||||
DUMMY_CLASS = """
|
||||
class {0}:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_{1}(self)
|
||||
requires_backends(self, {1})
|
||||
"""
|
||||
|
||||
DUMMY_FUNCTION = """
|
||||
def {0}(*args, **kwargs):
|
||||
requires_{1}({0})
|
||||
requires_backends({0}, {1})
|
||||
"""
|
||||
|
||||
|
||||
def find_backend(line):
    """
    Find one (or multiple) backend in a code line of the init.

    Returns `None` when the line is not matched by `_re_test_backend`. Otherwise every backend name captured by
    `_re_backend` on the line is collected, the names are sorted, and they are joined with `"_and_"` into one key.
    """
    # Lines that are not a backend availability test carry no backend at all.
    if _re_test_backend.search(line) is None:
        return None
    # The first capture group of `_re_backend` is the backend name; sorting makes the key order-independent.
    names = sorted(match.group(1) for match in _re_backend.finditer(line))
    return "_and_".join(names)
|
||||
|
||||
|
||||
def read_init():
|
||||
""" Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects. """
|
||||
with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f:
|
||||
@@ -69,14 +78,10 @@ def read_init():
|
||||
# Go through the end of the file
|
||||
while line_index < len(lines):
|
||||
# If the line is an if is_backend_available, we grab all objects associated.
|
||||
if _re_test_backend.search(lines[line_index]) is not None:
|
||||
backend = _re_test_backend.search(lines[line_index]).groups()[0]
|
||||
backend = find_backend(lines[line_index])
|
||||
if backend is not None:
|
||||
line_index += 1
|
||||
|
||||
# Ignore if backend isn't tracked for dummies.
|
||||
if backend not in BACKENDS:
|
||||
continue
|
||||
|
||||
objects = []
|
||||
# Until we unindent, add backend objects to the list
|
||||
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8):
|
||||
@@ -128,13 +133,12 @@ def create_dummy_files():
|
||||
""" Create the content of the dummy files. """
|
||||
backend_specific_objects = read_init()
|
||||
# For special correspondence backend to module name as used in the function requires_modulename
|
||||
module_names = {"torch": "pytorch"}
|
||||
dummy_files = {}
|
||||
|
||||
for backend, objects in backend_specific_objects.items():
|
||||
backend_name = module_names.get(backend, backend)
|
||||
backend_name = "[" + ", ".join(f'"{b}"' for b in backend.split("_and_")) + "]"
|
||||
dummy_file = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
|
||||
dummy_file += f"from ..file_utils import requires_{backend_name}\n\n"
|
||||
dummy_file += "from ..file_utils import requires_backends\n\n"
|
||||
dummy_file += "\n".join([create_dummy_object(o, backend_name) for o in objects])
|
||||
dummy_files[backend] = dummy_file
|
||||
|
||||
@@ -156,8 +160,11 @@ def check_dummies(overwrite=False):
|
||||
|
||||
actual_dummies = {}
|
||||
for backend, file_path in dummy_file_paths.items():
|
||||
with open(file_path, "r", encoding="utf-8", newline="\n") as f:
|
||||
actual_dummies[backend] = f.read()
|
||||
if os.path.isfile(file_path):
|
||||
with open(file_path, "r", encoding="utf-8", newline="\n") as f:
|
||||
actual_dummies[backend] = f.read()
|
||||
else:
|
||||
actual_dummies[backend] = ""
|
||||
|
||||
for backend in dummy_files.keys():
|
||||
if dummy_files[backend] != actual_dummies[backend]:
|
||||
|
||||
@@ -18,12 +18,14 @@ import re
|
||||
|
||||
|
||||
PATH_TO_TRANSFORMERS = "src/transformers"
|
||||
BACKENDS = ["torch", "tf", "flax", "sentencepiece", "speech", "tokenizers", "vision"]
|
||||
|
||||
|
||||
# Matches is_xxx_available()
|
||||
_re_backend = re.compile(r"is\_([a-z]*)_available()")
|
||||
# Catches a line with a key-values pattern: "bla": ["foo", "bar"]
|
||||
_re_import_struct_key_value = re.compile(r'\s+"\S*":\s+\[([^\]]*)\]')
|
||||
# Catches a line if is_foo_available
|
||||
_re_test_backend = re.compile(r"^\s*if\s+is\_([a-z]*)\_available\(\):\s*$")
|
||||
_re_test_backend = re.compile(r"^\s*if\s+is\_[a-z]*\_available\(\)")
|
||||
# Catches a line _import_structure["bla"].append("foo")
|
||||
_re_import_struct_add_one = re.compile(r'^\s*_import_structure\["\S*"\]\.append\("(\S*)"\)')
|
||||
# Catches a line _import_structure["bla"].extend(["foo", "bar"]) or _import_structure["bla"] = ["foo", "bar"]
|
||||
@@ -36,6 +38,15 @@ _re_between_brackets = re.compile("^\s+\[([^\]]+)\]")
|
||||
_re_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
|
||||
|
||||
|
||||
def find_backend(line):
    """
    Find one (or multiple) backend in a code line of the init.

    Returns `None` when the line is not matched by `_re_test_backend`. Otherwise every backend name captured by
    `_re_backend` on the line is collected, the names are sorted, and they are joined with `"_and_"` into one key.
    """
    # Lines that are not a backend availability test carry no backend at all.
    if _re_test_backend.search(line) is None:
        return None
    # The first capture group of `_re_backend` is the backend name; sorting makes the key order-independent.
    names = sorted(match.group(1) for match in _re_backend.finditer(line))
    return "_and_".join(names)
|
||||
|
||||
|
||||
def parse_init(init_file):
|
||||
"""
|
||||
Read an init_file and parse (per backend) the _import_structure objects defined and the TYPE_CHECKING objects
|
||||
@@ -54,7 +65,7 @@ def parse_init(init_file):
|
||||
|
||||
# First grab the objects without a specific backend in _import_structure
|
||||
objects = []
|
||||
while not lines[line_index].startswith("if TYPE_CHECKING") and _re_test_backend.search(lines[line_index]) is None:
|
||||
while not lines[line_index].startswith("if TYPE_CHECKING") and find_backend(lines[line_index]) is None:
|
||||
line = lines[line_index]
|
||||
single_line_import_search = _re_import_struct_key_value.search(line)
|
||||
if single_line_import_search is not None:
|
||||
@@ -68,14 +79,10 @@ def parse_init(init_file):
|
||||
# Let's continue with backend-specific objects in _import_structure
|
||||
while not lines[line_index].startswith("if TYPE_CHECKING"):
|
||||
# If the line is an if is_backend_available, we grab all objects associated.
|
||||
if _re_test_backend.search(lines[line_index]) is not None:
|
||||
backend = _re_test_backend.search(lines[line_index]).groups()[0]
|
||||
backend = find_backend(lines[line_index])
|
||||
if backend is not None:
|
||||
line_index += 1
|
||||
|
||||
# Ignore if backend isn't tracked for dummies.
|
||||
if backend not in BACKENDS:
|
||||
continue
|
||||
|
||||
objects = []
|
||||
# Until we unindent, add backend objects to the list
|
||||
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 4):
|
||||
@@ -106,7 +113,7 @@ def parse_init(init_file):
|
||||
objects = []
|
||||
while (
|
||||
line_index < len(lines)
|
||||
and _re_test_backend.search(lines[line_index]) is None
|
||||
and find_backend(lines[line_index]) is None
|
||||
and not lines[line_index].startswith("else")
|
||||
):
|
||||
line = lines[line_index]
|
||||
@@ -121,14 +128,10 @@ def parse_init(init_file):
|
||||
# Let's continue with backend-specific objects
|
||||
while line_index < len(lines):
|
||||
# If the line is an if is_backend_available, we grab all objects associated.
|
||||
if _re_test_backend.search(lines[line_index]) is not None:
|
||||
backend = _re_test_backend.search(lines[line_index]).groups()[0]
|
||||
backend = find_backend(lines[line_index])
|
||||
if backend is not None:
|
||||
line_index += 1
|
||||
|
||||
# Ignore if backend isn't tracked for dummies.
|
||||
if backend not in BACKENDS:
|
||||
continue
|
||||
|
||||
objects = []
|
||||
# Until we unindent, add backend objects to the list
|
||||
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8):
|
||||
|
||||
Reference in New Issue
Block a user