Better dummies (#15148)

* Better dummies

* See if this fixes the issue

* Fix quality

* Style

* Add doc for DummyObject
This commit is contained in:
Sylvain Gugger
2022-01-14 10:59:41 -05:00
committed by GitHub
parent b212ff9f49
commit 1b730c3d11
15 changed files with 3287 additions and 6876 deletions

View File

@@ -33,46 +33,15 @@ DUMMY_CONSTANT = """
{0} = None
"""
DUMMY_PRETRAINED_CLASS = """
class {0}:
def __init__(self, *args, **kwargs):
requires_backends(self, {1})
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, {1})
"""
PT_DUMMY_PRETRAINED_CLASS = (
DUMMY_PRETRAINED_CLASS
+ """
def forward(self, *args, **kwargs):
requires_backends(self, {1})
"""
)
TF_DUMMY_PRETRAINED_CLASS = (
DUMMY_PRETRAINED_CLASS
+ """
def call(self, *args, **kwargs):
requires_backends(self, {1})
"""
)
FLAX_DUMMY_PRETRAINED_CLASS = (
DUMMY_PRETRAINED_CLASS
+ """
def __call__(self, *args, **kwargs):
requires_backends(self, {1})
"""
)
DUMMY_CLASS = """
class {0}:
class {0}(metaclass=DummyObject):
_backends = {1}
def __init__(self, *args, **kwargs):
requires_backends(self, {1})
"""
DUMMY_FUNCTION = """
def {0}(*args, **kwargs):
requires_backends({0}, {1})
@@ -126,45 +95,12 @@ def read_init():
def create_dummy_object(name, backend_name):
"""Create the code for the dummy object corresponding to `name`."""
_models = [
"ForCausalLM",
"ForConditionalGeneration",
"ForMaskedLM",
"ForMultipleChoice",
"ForNextSentencePrediction",
"ForObjectDetection",
"ForQuestionAnswering",
"ForSegmentation",
"ForSequenceClassification",
"ForTokenClassification",
"Model",
]
_pretrained = ["Config", "Tokenizer", "Processor"]
if name.isupper():
return DUMMY_CONSTANT.format(name)
elif name.islower():
return DUMMY_FUNCTION.format(name, backend_name)
else:
is_model = False
for part in _models:
if part in name:
is_model = True
break
if is_model:
if name.startswith("TF"):
return TF_DUMMY_PRETRAINED_CLASS.format(name, backend_name)
if name.startswith("Flax"):
return FLAX_DUMMY_PRETRAINED_CLASS.format(name, backend_name)
return PT_DUMMY_PRETRAINED_CLASS.format(name, backend_name)
is_pretrained = False
for part in _pretrained:
if part in name:
is_pretrained = True
break
if is_pretrained:
return DUMMY_PRETRAINED_CLASS.format(name, backend_name)
else:
return DUMMY_CLASS.format(name, backend_name)
return DUMMY_CLASS.format(name, backend_name)
def create_dummy_files():
@@ -176,7 +112,8 @@ def create_dummy_files():
for backend, objects in backend_specific_objects.items():
backend_name = "[" + ", ".join(f'"{b}"' for b in backend.split("_and_")) + "]"
dummy_file = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
dummy_file += "from ..file_utils import requires_backends\n\n"
dummy_file += "# flake8: noqa\n"
dummy_file += "from ..file_utils import DummyObject, requires_backends\n\n"
dummy_file += "\n".join([create_dummy_object(o, backend_name) for o in objects])
dummy_files[backend] = dummy_file