Add forward method to dummy models (#14419)
* Add forward method to dummy models * Fix quality
This commit is contained in:
@@ -43,6 +43,30 @@ class {0}:
|
||||
requires_backends(cls, {1})
|
||||
"""
|
||||
|
||||
PT_DUMMY_PRETRAINED_CLASS = (
|
||||
DUMMY_PRETRAINED_CLASS
|
||||
+ """
|
||||
def forward(self, *args, **kwargs):
|
||||
requires_backends(self, {1})
|
||||
"""
|
||||
)
|
||||
|
||||
TF_DUMMY_PRETRAINED_CLASS = (
|
||||
DUMMY_PRETRAINED_CLASS
|
||||
+ """
|
||||
def call(self, *args, **kwargs):
|
||||
requires_backends(self, {1})
|
||||
"""
|
||||
)
|
||||
|
||||
FLAX_DUMMY_PRETRAINED_CLASS = (
|
||||
DUMMY_PRETRAINED_CLASS
|
||||
+ """
|
||||
def __call__(self, *args, **kwargs):
|
||||
requires_backends(self, {1})
|
||||
"""
|
||||
)
|
||||
|
||||
DUMMY_CLASS = """
|
||||
class {0}:
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -102,8 +126,7 @@ def read_init():
|
||||
|
||||
def create_dummy_object(name, backend_name):
|
||||
"""Create the code for the dummy object corresponding to `name`."""
|
||||
_pretrained = [
|
||||
"Config",
|
||||
_models = [
|
||||
"ForCausalLM",
|
||||
"ForConditionalGeneration",
|
||||
"ForMaskedLM",
|
||||
@@ -114,14 +137,24 @@ def create_dummy_object(name, backend_name):
|
||||
"ForSequenceClassification",
|
||||
"ForTokenClassification",
|
||||
"Model",
|
||||
"Tokenizer",
|
||||
"Processor",
|
||||
]
|
||||
_pretrained = ["Config", "Tokenizer", "Processor"]
|
||||
if name.isupper():
|
||||
return DUMMY_CONSTANT.format(name)
|
||||
elif name.islower():
|
||||
return DUMMY_FUNCTION.format(name, backend_name)
|
||||
else:
|
||||
is_model = False
|
||||
for part in _models:
|
||||
if part in name:
|
||||
is_model = True
|
||||
break
|
||||
if is_model:
|
||||
if name.startswith("TF"):
|
||||
return TF_DUMMY_PRETRAINED_CLASS.format(name, backend_name)
|
||||
if name.startswith("Flax"):
|
||||
return FLAX_DUMMY_PRETRAINED_CLASS.format(name, backend_name)
|
||||
return PT_DUMMY_PRETRAINED_CLASS.format(name, backend_name)
|
||||
is_pretrained = False
|
||||
for part in _pretrained:
|
||||
if part in name:
|
||||
|
||||
Reference in New Issue
Block a user