From 1a92bc578866bc0eee1104d253dbad98b89ecc3b Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Sun, 21 Nov 2021 17:39:20 -0500 Subject: [PATCH] Fix dummy objects for quantization (#14478) * Fix dummy objects for quantization * Add more models --- src/transformers/utils/dummy_flax_objects.py | 7 ++++ src/transformers/utils/dummy_pt_objects.py | 35 +++++++++++++++++++ ..._pytorch_quantization_and_torch_objects.py | 31 ++++++++++++++++ src/transformers/utils/dummy_tf_objects.py | 14 ++++++++ utils/check_dummies.py | 3 +- 5 files changed, 89 insertions(+), 1 deletion(-) diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index cf54fab811..8060e731d7 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -477,6 +477,13 @@ class FlaxBertForNextSentencePrediction: def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + def __call__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + class FlaxBertForPreTraining: def __init__(self, *args, **kwargs): diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 0ed3cf3dbd..fa8bb6d04c 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -56,6 +56,13 @@ class TextDatasetForNextSentencePrediction: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + def forward(self, *args, **kwargs): + requires_backends(self, ["torch"]) + class BeamScorer: def __init__(self, *args, **kwargs): @@ -783,6 +790,13 @@ class BertForNextSentencePrediction: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + def forward(self, *args, **kwargs): + requires_backends(self, ["torch"]) + class BertForPreTraining: def __init__(self, *args, **kwargs): @@ -2106,6 +2120,13 @@ class FNetForNextSentencePrediction: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + def forward(self, *args, **kwargs): + requires_backends(self, ["torch"]) + class FNetForPreTraining: def __init__(self, *args, **kwargs): @@ -3254,6 +3275,13 @@ class MegatronBertForNextSentencePrediction: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + def forward(self, *args, **kwargs): + requires_backends(self, ["torch"]) + class MegatronBertForPreTraining: def __init__(self, *args, **kwargs): @@ -3373,6 +3401,13 @@ class MobileBertForNextSentencePrediction: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + def forward(self, *args, **kwargs): + requires_backends(self, ["torch"]) + class MobileBertForPreTraining: def __init__(self, *args, **kwargs): diff --git a/src/transformers/utils/dummy_pytorch_quantization_and_torch_objects.py b/src/transformers/utils/dummy_pytorch_quantization_and_torch_objects.py index 79f9c54316..9f036f9e1a 100644 --- a/src/transformers/utils/dummy_pytorch_quantization_and_torch_objects.py +++ b/src/transformers/utils/dummy_pytorch_quantization_and_torch_objects.py @@ -13,6 +13,9 @@ class QDQBertForMaskedLM: def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["pytorch_quantization", "torch"]) + def forward(self, *args, **kwargs): + requires_backends(self, ["pytorch_quantization", "torch"]) + class QDQBertForMultipleChoice: def __init__(self, *args, **kwargs): @@ -22,11 +25,21 @@ class QDQBertForMultipleChoice: def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["pytorch_quantization", "torch"]) + def forward(self, *args, **kwargs): + requires_backends(self, ["pytorch_quantization", "torch"]) + class QDQBertForNextSentencePrediction: def __init__(self, *args, **kwargs): requires_backends(self, ["pytorch_quantization", "torch"]) + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["pytorch_quantization", "torch"]) + + def forward(self, *args, **kwargs): + requires_backends(self, ["pytorch_quantization", "torch"]) + class QDQBertForQuestionAnswering: def __init__(self, *args, **kwargs): @@ -36,6 +49,9 @@ class QDQBertForQuestionAnswering: def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["pytorch_quantization", "torch"]) + def forward(self, *args, **kwargs): + requires_backends(self, ["pytorch_quantization", "torch"]) + class QDQBertForSequenceClassification: def __init__(self, *args, **kwargs): @@ -45,6 +61,9 @@ class QDQBertForSequenceClassification: def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["pytorch_quantization", "torch"]) + def forward(self, *args, **kwargs): + requires_backends(self, ["pytorch_quantization", "torch"]) + class QDQBertForTokenClassification: def __init__(self, *args, **kwargs): @@ -54,6 +73,9 @@ class QDQBertForTokenClassification: def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["pytorch_quantization", "torch"]) + def forward(self, *args, **kwargs): + requires_backends(self, ["pytorch_quantization", "torch"]) + class QDQBertLayer: def __init__(self, *args, **kwargs): @@ -68,6 +90,9 @@ class QDQBertLMHeadModel: def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["pytorch_quantization", "torch"]) + def forward(self, *args, **kwargs): + requires_backends(self, ["pytorch_quantization", "torch"]) + class QDQBertModel: def __init__(self, *args, **kwargs): @@ -77,6 +102,9 @@ class QDQBertModel: def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["pytorch_quantization", "torch"]) + def forward(self, *args, **kwargs): + requires_backends(self, ["pytorch_quantization", "torch"]) + class QDQBertPreTrainedModel: def __init__(self, *args, **kwargs): @@ -86,6 +114,9 @@ class QDQBertPreTrainedModel: def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["pytorch_quantization", "torch"]) + def forward(self, *args, **kwargs): + requires_backends(self, ["pytorch_quantization", "torch"]) + def load_tf_weights_in_qdqbert(*args, **kwargs): requires_backends(load_tf_weights_in_qdqbert, ["pytorch_quantization", "torch"]) diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index f92ee0ad9d..14991e8b6a 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -452,6 +452,13 @@ class TFBertForNextSentencePrediction: def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["tf"]) + + def call(self, *args, **kwargs): + requires_backends(self, ["tf"]) + class TFBertForPreTraining: def __init__(self, *args, **kwargs): @@ -1774,6 +1781,13 @@ class TFMobileBertForNextSentencePrediction: def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["tf"]) + + def call(self, *args, **kwargs): + requires_backends(self, ["tf"]) + class TFMobileBertForPreTraining: def __init__(self, *args, **kwargs): diff --git a/utils/check_dummies.py b/utils/check_dummies.py index 336f38d75a..2084b32f13 100644 --- a/utils/check_dummies.py +++ b/utils/check_dummies.py @@ -23,7 +23,7 @@ import re PATH_TO_TRANSFORMERS = "src/transformers" # Matches is_xxx_available() -_re_backend = re.compile(r"is\_([a-z]*)_available()") +_re_backend = re.compile(r"is\_([a-z_]*)_available()") # Matches from xxx import bla _re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n") _re_test_backend = re.compile(r"^\s+if\s+is\_[a-z]*\_available\(\)") @@ -131,6 +131,7 @@ def create_dummy_object(name, backend_name): "ForConditionalGeneration", "ForMaskedLM", "ForMultipleChoice", + "ForNextSentencePrediction", "ForObjectDetection", "ForQuestionAnswering", "ForSegmentation",