Allow soft dependencies in the namespace with ImportErrors at use (#7537)
* PoC on RAG * Format class name/obj name * Better name in message * PoC on one TF model * Add PyTorch and TF dummy objects + script * Treat scikit-learn * Bad copy pastes * Typo
This commit is contained in:
199
utils/check_dummies.py
Normal file
199
utils/check_dummies.py
Normal file
@@ -0,0 +1,199 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
|
||||
|
||||
# All paths are set with the intent you should run this script from the root of the repo with the command
|
||||
# python utils/check_dummies.py
|
||||
PATH_TO_TRANSFORMERS = "src/transformers"
|
||||
|
||||
_re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
|
||||
|
||||
DUMMY_CONSTANT = """
|
||||
{0} = None
|
||||
"""
|
||||
|
||||
DUMMY_PT_PRETRAINED_CLASS = """
|
||||
class {0}:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
"""
|
||||
|
||||
DUMMY_PT_CLASS = """
|
||||
class {0}:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
"""
|
||||
|
||||
DUMMY_PT_FUNCTION = """
|
||||
def {0}(*args, **kwargs):
|
||||
requires_pytorch({0})
|
||||
"""
|
||||
|
||||
DUMMY_TF_PRETRAINED_CLASS = """
|
||||
class {0}:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
"""
|
||||
|
||||
DUMMY_TF_CLASS = """
|
||||
class {0}:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
"""
|
||||
|
||||
DUMMY_TF_FUNCTION = """
|
||||
def {0}(*args, **kwargs):
|
||||
requires_tf({0})
|
||||
"""
|
||||
|
||||
|
||||
def read_init():
|
||||
""" Read the init and exctracts PyTorch and TensorFlow objects. """
|
||||
with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
line_index = 0
|
||||
# Find where the PyTorch imports begin
|
||||
pt_objects = []
|
||||
while not lines[line_index].startswith("if is_torch_available():"):
|
||||
line_index += 1
|
||||
line_index += 1
|
||||
|
||||
# Until we unindent, add PyTorch objects to the list
|
||||
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" "):
|
||||
line = lines[line_index]
|
||||
search = _re_single_line_import.search(line)
|
||||
if search is not None:
|
||||
pt_objects += search.groups()[0].split(", ")
|
||||
elif line.startswith(" "):
|
||||
pt_objects.append(line[8:-2])
|
||||
line_index += 1
|
||||
|
||||
# Find where the TF imports begin
|
||||
tf_objects = []
|
||||
while not lines[line_index].startswith("if is_tf_available():"):
|
||||
line_index += 1
|
||||
line_index += 1
|
||||
|
||||
# Until we unindent, add PyTorch objects to the list
|
||||
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" "):
|
||||
line = lines[line_index]
|
||||
search = _re_single_line_import.search(line)
|
||||
if search is not None:
|
||||
tf_objects += search.groups()[0].split(", ")
|
||||
elif line.startswith(" "):
|
||||
tf_objects.append(line[8:-2])
|
||||
line_index += 1
|
||||
return pt_objects, tf_objects
|
||||
|
||||
|
||||
def create_dummy_object(name, is_pytorch=True):
|
||||
""" Create the code for the dummy object corresponding to `name`."""
|
||||
_pretrained = [
|
||||
"Config" "ForCausalLM",
|
||||
"ForConditionalGeneration",
|
||||
"ForMaskedLM",
|
||||
"ForMultipleChoice",
|
||||
"ForQuestionAnswering",
|
||||
"ForSequenceClassification",
|
||||
"ForTokenClassification",
|
||||
"Model",
|
||||
"Tokenizer",
|
||||
]
|
||||
if name.isupper():
|
||||
return DUMMY_CONSTANT.format(name)
|
||||
elif name.islower():
|
||||
return (DUMMY_PT_FUNCTION if is_pytorch else DUMMY_TF_FUNCTION).format(name)
|
||||
else:
|
||||
is_pretrained = False
|
||||
for part in _pretrained:
|
||||
if part in name:
|
||||
is_pretrained = True
|
||||
break
|
||||
if is_pretrained:
|
||||
template = DUMMY_PT_PRETRAINED_CLASS if is_pytorch else DUMMY_TF_PRETRAINED_CLASS
|
||||
else:
|
||||
template = DUMMY_PT_CLASS if is_pytorch else DUMMY_TF_CLASS
|
||||
return template.format(name)
|
||||
|
||||
|
||||
def create_dummy_files():
|
||||
""" Create the content of the dummy files. """
|
||||
pt_objects, tf_objects = read_init()
|
||||
|
||||
pt_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
|
||||
pt_dummies += "from ..file_utils import requires_pytorch\n\n"
|
||||
pt_dummies += "\n".join([create_dummy_object(o) for o in pt_objects])
|
||||
|
||||
tf_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
|
||||
tf_dummies += "from ..file_utils import requires_tf\n\n"
|
||||
tf_dummies += "\n".join([create_dummy_object(o, False) for o in tf_objects])
|
||||
|
||||
return pt_dummies, tf_dummies
|
||||
|
||||
|
||||
def check_dummies(overwrite=False):
|
||||
""" Check if the dummy files are up to date and maybe `overwrite` with the right content. """
|
||||
pt_dummies, tf_dummies = create_dummy_files()
|
||||
path = os.path.join(PATH_TO_TRANSFORMERS, "utils")
|
||||
pt_file = os.path.join(path, "dummy_pt_objects.py")
|
||||
tf_file = os.path.join(path, "dummy_tf_objects.py")
|
||||
|
||||
with open(pt_file, "r", encoding="utf-8") as f:
|
||||
actual_pt_dummies = f.read()
|
||||
with open(tf_file, "r", encoding="utf-8") as f:
|
||||
actual_tf_dummies = f.read()
|
||||
|
||||
if pt_dummies != actual_pt_dummies:
|
||||
if overwrite:
|
||||
print("Updating transformers.utils.dummy_pt_objects.py as the main __init__ has new objects.")
|
||||
with open(pt_file, "w", encoding="utf-8") as f:
|
||||
f.write(pt_dummies)
|
||||
else:
|
||||
raise ValueError(
|
||||
"The main __init__ has objects that are not present in transformers.utils.dummy_pt_objects.py.",
|
||||
"Run `make fix-copies` to fix this.",
|
||||
)
|
||||
|
||||
if tf_dummies != actual_tf_dummies:
|
||||
if overwrite:
|
||||
print("Updating transformers.utils.dummy_tf_objects.py as the main __init__ has new objects.")
|
||||
with open(tf_file, "w", encoding="utf-8") as f:
|
||||
f.write(tf_dummies)
|
||||
else:
|
||||
raise ValueError(
|
||||
"The main __init__ has objects that are not present in transformers.utils.dummy_pt_objects.py.",
|
||||
"Run `make fix-copies` to fix this.",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
|
||||
args = parser.parse_args()
|
||||
|
||||
check_dummies(args.fix_and_overwrite)
|
||||
Reference in New Issue
Block a user