Transformers fast import part 2 (#9446)
* Main init work * Add version * Change from absolute to relative imports * Fix imports * One more typo * More typos * Styling * Make quality script pass * Add necessary replace in template * Fix typos * Spaces are ignored in replace for some reason * Forgot one models. * Fixes for import Co-authored-by: LysandreJik <lysandre.debut@reseau.eseo.fr> * Add documentation * Styling Co-authored-by: LysandreJik <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
@@ -23,237 +23,79 @@ import re
|
||||
PATH_TO_TRANSFORMERS = "src/transformers"
|
||||
|
||||
_re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
|
||||
_re_test_backend = re.compile(r"^\s+if\s+is\_([a-z]*)\_available\(\):\s*$")
|
||||
|
||||
|
||||
BACKENDS = ["torch", "tf", "flax", "sentencepiece", "tokenizers"]
|
||||
|
||||
|
||||
DUMMY_CONSTANT = """
|
||||
{0} = None
|
||||
"""
|
||||
|
||||
DUMMY_PT_PRETRAINED_CLASS = """
|
||||
DUMMY_PRETRAINED_CLASS = """
|
||||
class {0}:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
requires_{1}(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
requires_{1}(self)
|
||||
"""
|
||||
|
||||
DUMMY_PT_CLASS = """
|
||||
DUMMY_CLASS = """
|
||||
class {0}:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
requires_{1}(self)
|
||||
"""
|
||||
|
||||
DUMMY_PT_FUNCTION = """
|
||||
DUMMY_FUNCTION = """
|
||||
def {0}(*args, **kwargs):
|
||||
requires_pytorch({0})
|
||||
requires_{1}({0})
|
||||
"""
|
||||
|
||||
|
||||
DUMMY_TF_PRETRAINED_CLASS = """
|
||||
class {0}:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
"""
|
||||
|
||||
DUMMY_TF_CLASS = """
|
||||
class {0}:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
"""
|
||||
|
||||
DUMMY_TF_FUNCTION = """
|
||||
def {0}(*args, **kwargs):
|
||||
requires_tf({0})
|
||||
"""
|
||||
|
||||
|
||||
DUMMY_FLAX_PRETRAINED_CLASS = """
|
||||
class {0}:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
"""
|
||||
|
||||
DUMMY_FLAX_CLASS = """
|
||||
class {0}:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
"""
|
||||
|
||||
DUMMY_FLAX_FUNCTION = """
|
||||
def {0}(*args, **kwargs):
|
||||
requires_flax({0})
|
||||
"""
|
||||
|
||||
|
||||
DUMMY_SENTENCEPIECE_PRETRAINED_CLASS = """
|
||||
class {0}:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_sentencepiece(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_sentencepiece(self)
|
||||
"""
|
||||
|
||||
DUMMY_SENTENCEPIECE_CLASS = """
|
||||
class {0}:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_sentencepiece(self)
|
||||
"""
|
||||
|
||||
DUMMY_SENTENCEPIECE_FUNCTION = """
|
||||
def {0}(*args, **kwargs):
|
||||
requires_sentencepiece({0})
|
||||
"""
|
||||
|
||||
|
||||
DUMMY_TOKENIZERS_PRETRAINED_CLASS = """
|
||||
class {0}:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_tokenizers(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_tokenizers(self)
|
||||
"""
|
||||
|
||||
DUMMY_TOKENIZERS_CLASS = """
|
||||
class {0}:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_tokenizers(self)
|
||||
"""
|
||||
|
||||
DUMMY_TOKENIZERS_FUNCTION = """
|
||||
def {0}(*args, **kwargs):
|
||||
requires_tokenizers({0})
|
||||
"""
|
||||
|
||||
# Map all these to dummy type
|
||||
|
||||
DUMMY_PRETRAINED_CLASS = {
|
||||
"pt": DUMMY_PT_PRETRAINED_CLASS,
|
||||
"tf": DUMMY_TF_PRETRAINED_CLASS,
|
||||
"flax": DUMMY_FLAX_PRETRAINED_CLASS,
|
||||
"sentencepiece": DUMMY_SENTENCEPIECE_PRETRAINED_CLASS,
|
||||
"tokenizers": DUMMY_TOKENIZERS_PRETRAINED_CLASS,
|
||||
}
|
||||
|
||||
DUMMY_CLASS = {
|
||||
"pt": DUMMY_PT_CLASS,
|
||||
"tf": DUMMY_TF_CLASS,
|
||||
"flax": DUMMY_FLAX_CLASS,
|
||||
"sentencepiece": DUMMY_SENTENCEPIECE_CLASS,
|
||||
"tokenizers": DUMMY_TOKENIZERS_CLASS,
|
||||
}
|
||||
|
||||
DUMMY_FUNCTION = {
|
||||
"pt": DUMMY_PT_FUNCTION,
|
||||
"tf": DUMMY_TF_FUNCTION,
|
||||
"flax": DUMMY_FLAX_FUNCTION,
|
||||
"sentencepiece": DUMMY_SENTENCEPIECE_FUNCTION,
|
||||
"tokenizers": DUMMY_TOKENIZERS_FUNCTION,
|
||||
}
|
||||
|
||||
|
||||
def read_init():
|
||||
""" Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects. """
|
||||
with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
# Get to the point we do the actual imports for type checking
|
||||
line_index = 0
|
||||
# Find where the SentencePiece imports begin
|
||||
sentencepiece_objects = []
|
||||
while not lines[line_index].startswith("if is_sentencepiece_available():"):
|
||||
line_index += 1
|
||||
line_index += 1
|
||||
|
||||
# Until we unindent, add SentencePiece objects to the list
|
||||
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" "):
|
||||
line = lines[line_index]
|
||||
search = _re_single_line_import.search(line)
|
||||
if search is not None:
|
||||
sentencepiece_objects += search.groups()[0].split(", ")
|
||||
elif line.startswith(" "):
|
||||
sentencepiece_objects.append(line[8:-2])
|
||||
while not lines[line_index].startswith("if TYPE_CHECKING"):
|
||||
line_index += 1
|
||||
|
||||
# Find where the Tokenizers imports begin
|
||||
tokenizers_objects = []
|
||||
while not lines[line_index].startswith("if is_tokenizers_available():"):
|
||||
line_index += 1
|
||||
line_index += 1
|
||||
backend_specific_objects = {}
|
||||
# Go through the end of the file
|
||||
while line_index < len(lines):
|
||||
# If the line is an if is_backemd_available, we grab all objects associated.
|
||||
if _re_test_backend.search(lines[line_index]) is not None:
|
||||
backend = _re_test_backend.search(lines[line_index]).groups()[0]
|
||||
line_index += 1
|
||||
|
||||
# Until we unindent, add Tokenizers objects to the list
|
||||
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" "):
|
||||
line = lines[line_index]
|
||||
search = _re_single_line_import.search(line)
|
||||
if search is not None:
|
||||
tokenizers_objects += search.groups()[0].split(", ")
|
||||
elif line.startswith(" "):
|
||||
tokenizers_objects.append(line[8:-2])
|
||||
line_index += 1
|
||||
# Ignore if backend isn't tracked for dummies.
|
||||
if backend not in BACKENDS:
|
||||
continue
|
||||
|
||||
# Find where the PyTorch imports begin
|
||||
pt_objects = []
|
||||
while not lines[line_index].startswith("if is_torch_available():"):
|
||||
line_index += 1
|
||||
line_index += 1
|
||||
objects = []
|
||||
# Until we unindent, add backend objects to the list
|
||||
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8):
|
||||
line = lines[line_index]
|
||||
single_line_import_search = _re_single_line_import.search(line)
|
||||
if single_line_import_search is not None:
|
||||
objects.extend(single_line_import_search.groups()[0].split(", "))
|
||||
elif line.startswith(" " * 12):
|
||||
objects.append(line[12:-2])
|
||||
line_index += 1
|
||||
|
||||
# Until we unindent, add PyTorch objects to the list
|
||||
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" "):
|
||||
line = lines[line_index]
|
||||
search = _re_single_line_import.search(line)
|
||||
if search is not None:
|
||||
pt_objects += search.groups()[0].split(", ")
|
||||
elif line.startswith(" "):
|
||||
pt_objects.append(line[8:-2])
|
||||
line_index += 1
|
||||
backend_specific_objects[backend] = objects
|
||||
else:
|
||||
line_index += 1
|
||||
|
||||
# Find where the TF imports begin
|
||||
tf_objects = []
|
||||
while not lines[line_index].startswith("if is_tf_available():"):
|
||||
line_index += 1
|
||||
line_index += 1
|
||||
|
||||
# Until we unindent, add PyTorch objects to the list
|
||||
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" "):
|
||||
line = lines[line_index]
|
||||
search = _re_single_line_import.search(line)
|
||||
if search is not None:
|
||||
tf_objects += search.groups()[0].split(", ")
|
||||
elif line.startswith(" "):
|
||||
tf_objects.append(line[8:-2])
|
||||
line_index += 1
|
||||
|
||||
# Find where the FLAX imports begin
|
||||
flax_objects = []
|
||||
while not lines[line_index].startswith("if is_flax_available():"):
|
||||
line_index += 1
|
||||
line_index += 1
|
||||
|
||||
# Until we unindent, add PyTorch objects to the list
|
||||
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" "):
|
||||
line = lines[line_index]
|
||||
search = _re_single_line_import.search(line)
|
||||
if search is not None:
|
||||
flax_objects += search.groups()[0].split(", ")
|
||||
elif line.startswith(" "):
|
||||
flax_objects.append(line[8:-2])
|
||||
line_index += 1
|
||||
|
||||
return sentencepiece_objects, tokenizers_objects, pt_objects, tf_objects, flax_objects
|
||||
return backend_specific_objects
|
||||
|
||||
|
||||
def create_dummy_object(name, type="pt"):
|
||||
def create_dummy_object(name, backend_name):
|
||||
""" Create the code for the dummy object corresponding to `name`."""
|
||||
_pretrained = [
|
||||
"Config" "ForCausalLM",
|
||||
@@ -266,11 +108,10 @@ def create_dummy_object(name, type="pt"):
|
||||
"Model",
|
||||
"Tokenizer",
|
||||
]
|
||||
assert type in ["pt", "tf", "sentencepiece", "tokenizers", "flax"]
|
||||
if name.isupper():
|
||||
return DUMMY_CONSTANT.format(name)
|
||||
elif name.islower():
|
||||
return (DUMMY_FUNCTION[type]).format(name)
|
||||
return DUMMY_FUNCTION.format(name, backend_name)
|
||||
else:
|
||||
is_pretrained = False
|
||||
for part in _pretrained:
|
||||
@@ -278,114 +119,61 @@ def create_dummy_object(name, type="pt"):
|
||||
is_pretrained = True
|
||||
break
|
||||
if is_pretrained:
|
||||
template = DUMMY_PRETRAINED_CLASS[type]
|
||||
return DUMMY_PRETRAINED_CLASS.format(name, backend_name)
|
||||
else:
|
||||
template = DUMMY_CLASS[type]
|
||||
return template.format(name)
|
||||
return DUMMY_CLASS.format(name, backend_name)
|
||||
|
||||
|
||||
def create_dummy_files():
|
||||
""" Create the content of the dummy files. """
|
||||
sentencepiece_objects, tokenizers_objects, pt_objects, tf_objects, flax_objects = read_init()
|
||||
backend_specific_objects = read_init()
|
||||
# For special correspondence backend to module name as used in the function requires_modulename
|
||||
module_names = {"torch": "pytorch"}
|
||||
dummy_files = {}
|
||||
|
||||
sentencepiece_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
|
||||
sentencepiece_dummies += "from ..file_utils import requires_sentencepiece\n\n"
|
||||
sentencepiece_dummies += "\n".join([create_dummy_object(o, type="sentencepiece") for o in sentencepiece_objects])
|
||||
for backend, objects in backend_specific_objects.items():
|
||||
backend_name = module_names.get(backend, backend)
|
||||
dummy_file = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
|
||||
dummy_file += f"from ..file_utils import requires_{backend_name}\n\n"
|
||||
dummy_file += "\n".join([create_dummy_object(o, backend_name) for o in objects])
|
||||
dummy_files[backend] = dummy_file
|
||||
|
||||
tokenizers_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
|
||||
tokenizers_dummies += "from ..file_utils import requires_tokenizers\n\n"
|
||||
tokenizers_dummies += "\n".join([create_dummy_object(o, type="tokenizers") for o in tokenizers_objects])
|
||||
|
||||
pt_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
|
||||
pt_dummies += "from ..file_utils import requires_pytorch\n\n"
|
||||
pt_dummies += "\n".join([create_dummy_object(o, type="pt") for o in pt_objects])
|
||||
|
||||
tf_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
|
||||
tf_dummies += "from ..file_utils import requires_tf\n\n"
|
||||
tf_dummies += "\n".join([create_dummy_object(o, type="tf") for o in tf_objects])
|
||||
|
||||
flax_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
|
||||
flax_dummies += "from ..file_utils import requires_flax\n\n"
|
||||
flax_dummies += "\n".join([create_dummy_object(o, type="flax") for o in flax_objects])
|
||||
|
||||
return sentencepiece_dummies, tokenizers_dummies, pt_dummies, tf_dummies, flax_dummies
|
||||
return dummy_files
|
||||
|
||||
|
||||
def check_dummies(overwrite=False):
|
||||
""" Check if the dummy files are up to date and maybe `overwrite` with the right content. """
|
||||
sentencepiece_dummies, tokenizers_dummies, pt_dummies, tf_dummies, flax_dummies = create_dummy_files()
|
||||
dummy_files = create_dummy_files()
|
||||
# For special correspondence backend to shortcut as used in utils/dummy_xxx_objects.py
|
||||
short_names = {"torch": "pt"}
|
||||
|
||||
# Locate actual dummy modules and read their content.
|
||||
path = os.path.join(PATH_TO_TRANSFORMERS, "utils")
|
||||
sentencepiece_file = os.path.join(path, "dummy_sentencepiece_objects.py")
|
||||
tokenizers_file = os.path.join(path, "dummy_tokenizers_objects.py")
|
||||
pt_file = os.path.join(path, "dummy_pt_objects.py")
|
||||
tf_file = os.path.join(path, "dummy_tf_objects.py")
|
||||
flax_file = os.path.join(path, "dummy_flax_objects.py")
|
||||
dummy_file_paths = {
|
||||
backend: os.path.join(path, f"dummy_{short_names.get(backend, backend)}_objects.py")
|
||||
for backend in dummy_files.keys()
|
||||
}
|
||||
|
||||
with open(sentencepiece_file, "r", encoding="utf-8", newline="\n") as f:
|
||||
actual_sentencepiece_dummies = f.read()
|
||||
with open(tokenizers_file, "r", encoding="utf-8", newline="\n") as f:
|
||||
actual_tokenizers_dummies = f.read()
|
||||
with open(pt_file, "r", encoding="utf-8", newline="\n") as f:
|
||||
actual_pt_dummies = f.read()
|
||||
with open(tf_file, "r", encoding="utf-8", newline="\n") as f:
|
||||
actual_tf_dummies = f.read()
|
||||
with open(flax_file, "r", encoding="utf-8", newline="\n") as f:
|
||||
actual_flax_dummies = f.read()
|
||||
actual_dummies = {}
|
||||
for backend, file_path in dummy_file_paths.items():
|
||||
with open(file_path, "r", encoding="utf-8", newline="\n") as f:
|
||||
actual_dummies[backend] = f.read()
|
||||
|
||||
if sentencepiece_dummies != actual_sentencepiece_dummies:
|
||||
if overwrite:
|
||||
print("Updating transformers.utils.dummy_sentencepiece_objects.py as the main __init__ has new objects.")
|
||||
with open(sentencepiece_file, "w", encoding="utf-8", newline="\n") as f:
|
||||
f.write(sentencepiece_dummies)
|
||||
else:
|
||||
raise ValueError(
|
||||
"The main __init__ has objects that are not present in transformers.utils.dummy_sentencepiece_objects.py.",
|
||||
"Run `make fix-copies` to fix this.",
|
||||
)
|
||||
|
||||
if tokenizers_dummies != actual_tokenizers_dummies:
|
||||
if overwrite:
|
||||
print("Updating transformers.utils.dummy_tokenizers_objects.py as the main __init__ has new objects.")
|
||||
with open(tokenizers_file, "w", encoding="utf-8", newline="\n") as f:
|
||||
f.write(tokenizers_dummies)
|
||||
else:
|
||||
raise ValueError(
|
||||
"The main __init__ has objects that are not present in transformers.utils.dummy_tokenizers_objects.py.",
|
||||
"Run `make fix-copies` to fix this.",
|
||||
)
|
||||
|
||||
if pt_dummies != actual_pt_dummies:
|
||||
if overwrite:
|
||||
print("Updating transformers.utils.dummy_pt_objects.py as the main __init__ has new objects.")
|
||||
with open(pt_file, "w", encoding="utf-8", newline="\n") as f:
|
||||
f.write(pt_dummies)
|
||||
else:
|
||||
raise ValueError(
|
||||
"The main __init__ has objects that are not present in transformers.utils.dummy_pt_objects.py.",
|
||||
"Run `make fix-copies` to fix this.",
|
||||
)
|
||||
|
||||
if tf_dummies != actual_tf_dummies:
|
||||
if overwrite:
|
||||
print("Updating transformers.utils.dummy_tf_objects.py as the main __init__ has new objects.")
|
||||
with open(tf_file, "w", encoding="utf-8", newline="\n") as f:
|
||||
f.write(tf_dummies)
|
||||
else:
|
||||
raise ValueError(
|
||||
"The main __init__ has objects that are not present in transformers.utils.dummy_pt_objects.py.",
|
||||
"Run `make fix-copies` to fix this.",
|
||||
)
|
||||
|
||||
if flax_dummies != actual_flax_dummies:
|
||||
if overwrite:
|
||||
print("Updating transformers.utils.dummy_flax_objects.py as the main __init__ has new objects.")
|
||||
with open(flax_file, "w", encoding="utf-8", newline="\n") as f:
|
||||
f.write(flax_dummies)
|
||||
else:
|
||||
raise ValueError(
|
||||
"The main __init__ has objects that are not present in transformers.utils.dummy_flax_objects.py.",
|
||||
"Run `make fix-copies` to fix this.",
|
||||
)
|
||||
for backend in dummy_files.keys():
|
||||
if dummy_files[backend] != actual_dummies[backend]:
|
||||
if overwrite:
|
||||
print(
|
||||
f"Updating transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py as the main "
|
||||
"__init__ has new objects."
|
||||
)
|
||||
with open(dummy_file_paths[backend], "w", encoding="utf-8", newline="\n") as f:
|
||||
f.write(dummy_files[backend])
|
||||
else:
|
||||
raise ValueError(
|
||||
"The main __init__ has objects that are not present in "
|
||||
f"transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py. Run `make fix-copies` "
|
||||
"to fix this."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -413,9 +413,6 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [
|
||||
def ignore_undocumented(name):
|
||||
"""Rules to determine if `name` should be undocumented."""
|
||||
# NOT DOCUMENTED ON PURPOSE.
|
||||
# Magic attributes are not documented.
|
||||
if name.startswith("__"):
|
||||
return True
|
||||
# Constants uppercase are not documented.
|
||||
if name.isupper():
|
||||
return True
|
||||
@@ -459,7 +456,9 @@ def ignore_undocumented(name):
|
||||
def check_all_objects_are_documented():
|
||||
""" Check all models are properly documented."""
|
||||
documented_objs = find_all_documented_objects()
|
||||
undocumented_objs = [c for c in dir(transformers) if c not in documented_objs and not ignore_undocumented(c)]
|
||||
modules = transformers._modules
|
||||
objects = [c for c in dir(transformers) if c not in modules and not c.startswith("_")]
|
||||
undocumented_objs = [c for c in objects if c not in documented_objs and not ignore_undocumented(c)]
|
||||
if len(undocumented_objs) > 0:
|
||||
raise Exception(
|
||||
"The following objects are in the public init so should be documented:\n - "
|
||||
|
||||
Reference in New Issue
Block a user