Add forward method to dummy models (#14419)

* Add forward method to dummy models

* Fix quality
This commit is contained in:
Sylvain Gugger
2021-11-16 09:24:40 -05:00
committed by GitHub
parent 040fd47162
commit 3e8d17e66d
6 changed files with 2071 additions and 4 deletions

View File

@@ -43,6 +43,30 @@ class {0}:
requires_backends(cls, {1})
"""
PT_DUMMY_PRETRAINED_CLASS = (
DUMMY_PRETRAINED_CLASS
+ """
def forward(self, *args, **kwargs):
requires_backends(self, {1})
"""
)
TF_DUMMY_PRETRAINED_CLASS = (
DUMMY_PRETRAINED_CLASS
+ """
def call(self, *args, **kwargs):
requires_backends(self, {1})
"""
)
FLAX_DUMMY_PRETRAINED_CLASS = (
DUMMY_PRETRAINED_CLASS
+ """
def __call__(self, *args, **kwargs):
requires_backends(self, {1})
"""
)
DUMMY_CLASS = """
class {0}:
def __init__(self, *args, **kwargs):
@@ -102,8 +126,7 @@ def read_init():
def create_dummy_object(name, backend_name):
"""Create the code for the dummy object corresponding to `name`."""
_pretrained = [
"Config",
_models = [
"ForCausalLM",
"ForConditionalGeneration",
"ForMaskedLM",
@@ -114,14 +137,24 @@ def create_dummy_object(name, backend_name):
"ForSequenceClassification",
"ForTokenClassification",
"Model",
"Tokenizer",
"Processor",
]
_pretrained = ["Config", "Tokenizer", "Processor"]
if name.isupper():
return DUMMY_CONSTANT.format(name)
elif name.islower():
return DUMMY_FUNCTION.format(name, backend_name)
else:
is_model = False
for part in _models:
if part in name:
is_model = True
break
if is_model:
if name.startswith("TF"):
return TF_DUMMY_PRETRAINED_CLASS.format(name, backend_name)
if name.startswith("Flax"):
return FLAX_DUMMY_PRETRAINED_CLASS.format(name, backend_name)
return PT_DUMMY_PRETRAINED_CLASS.format(name, backend_name)
is_pretrained = False
for part in _pretrained:
if part in name: