Better dummies (#15148)
* Better dummies * See if this fixes the issue * Fix quality * Style * Add doc for DummyObject
This commit is contained in:
@@ -33,46 +33,15 @@ DUMMY_CONSTANT = """
|
||||
{0} = None
|
||||
"""
|
||||
|
||||
DUMMY_PRETRAINED_CLASS = """
|
||||
class {0}:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, {1})
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, {1})
|
||||
"""
|
||||
|
||||
PT_DUMMY_PRETRAINED_CLASS = (
|
||||
DUMMY_PRETRAINED_CLASS
|
||||
+ """
|
||||
def forward(self, *args, **kwargs):
|
||||
requires_backends(self, {1})
|
||||
"""
|
||||
)
|
||||
|
||||
TF_DUMMY_PRETRAINED_CLASS = (
|
||||
DUMMY_PRETRAINED_CLASS
|
||||
+ """
|
||||
def call(self, *args, **kwargs):
|
||||
requires_backends(self, {1})
|
||||
"""
|
||||
)
|
||||
|
||||
FLAX_DUMMY_PRETRAINED_CLASS = (
|
||||
DUMMY_PRETRAINED_CLASS
|
||||
+ """
|
||||
def __call__(self, *args, **kwargs):
|
||||
requires_backends(self, {1})
|
||||
"""
|
||||
)
|
||||
|
||||
DUMMY_CLASS = """
|
||||
class {0}:
|
||||
class {0}(metaclass=DummyObject):
|
||||
_backends = {1}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, {1})
|
||||
"""
|
||||
|
||||
|
||||
DUMMY_FUNCTION = """
|
||||
def {0}(*args, **kwargs):
|
||||
requires_backends({0}, {1})
|
||||
@@ -126,45 +95,12 @@ def read_init():
|
||||
|
||||
def create_dummy_object(name, backend_name):
|
||||
"""Create the code for the dummy object corresponding to `name`."""
|
||||
_models = [
|
||||
"ForCausalLM",
|
||||
"ForConditionalGeneration",
|
||||
"ForMaskedLM",
|
||||
"ForMultipleChoice",
|
||||
"ForNextSentencePrediction",
|
||||
"ForObjectDetection",
|
||||
"ForQuestionAnswering",
|
||||
"ForSegmentation",
|
||||
"ForSequenceClassification",
|
||||
"ForTokenClassification",
|
||||
"Model",
|
||||
]
|
||||
_pretrained = ["Config", "Tokenizer", "Processor"]
|
||||
if name.isupper():
|
||||
return DUMMY_CONSTANT.format(name)
|
||||
elif name.islower():
|
||||
return DUMMY_FUNCTION.format(name, backend_name)
|
||||
else:
|
||||
is_model = False
|
||||
for part in _models:
|
||||
if part in name:
|
||||
is_model = True
|
||||
break
|
||||
if is_model:
|
||||
if name.startswith("TF"):
|
||||
return TF_DUMMY_PRETRAINED_CLASS.format(name, backend_name)
|
||||
if name.startswith("Flax"):
|
||||
return FLAX_DUMMY_PRETRAINED_CLASS.format(name, backend_name)
|
||||
return PT_DUMMY_PRETRAINED_CLASS.format(name, backend_name)
|
||||
is_pretrained = False
|
||||
for part in _pretrained:
|
||||
if part in name:
|
||||
is_pretrained = True
|
||||
break
|
||||
if is_pretrained:
|
||||
return DUMMY_PRETRAINED_CLASS.format(name, backend_name)
|
||||
else:
|
||||
return DUMMY_CLASS.format(name, backend_name)
|
||||
return DUMMY_CLASS.format(name, backend_name)
|
||||
|
||||
|
||||
def create_dummy_files():
|
||||
@@ -176,7 +112,8 @@ def create_dummy_files():
|
||||
for backend, objects in backend_specific_objects.items():
|
||||
backend_name = "[" + ", ".join(f'"{b}"' for b in backend.split("_and_")) + "]"
|
||||
dummy_file = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
|
||||
dummy_file += "from ..file_utils import requires_backends\n\n"
|
||||
dummy_file += "# flake8: noqa\n"
|
||||
dummy_file += "from ..file_utils import DummyObject, requires_backends\n\n"
|
||||
dummy_file += "\n".join([create_dummy_object(o, backend_name) for o in objects])
|
||||
dummy_files[backend] = dummy_file
|
||||
|
||||
|
||||
Reference in New Issue
Block a user