From d07b540a37afbe790a093183b3bb105c5ff5f7ec Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Tue, 15 Jun 2021 14:39:05 +0200 Subject: [PATCH] Have dummy processors have a `from_pretrained` method (#12145) --- src/transformers/utils/dummy_flax_objects.py | 8 ++++ src/transformers/utils/dummy_pt_objects.py | 44 +++++++++++++++++++ .../dummy_sentencepiece_and_speech_objects.py | 4 ++ .../utils/dummy_vision_objects.py | 4 ++ utils/check_dummies.py | 1 + 5 files changed, 61 insertions(+) diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index 5bc72929b4..f4cbcb2496 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -6,11 +6,19 @@ class FlaxLogitsProcessor: def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + class FlaxLogitsProcessorList: def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + class FlaxLogitsWarper: def __init__(self, *args, **kwargs): diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 1fa3f30cf5..0a995c29cb 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -127,31 +127,55 @@ class ForcedBOSTokenLogitsProcessor: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + class ForcedEOSTokenLogitsProcessor: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + class HammingDiversityLogitsProcessor: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + class InfNanRemoveLogitsProcessor: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + class LogitsProcessor: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + class LogitsProcessorList: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + class LogitsWarper: def __init__(self, *args, **kwargs): @@ -162,26 +186,46 @@ class MinLengthLogitsProcessor: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + class NoBadWordsLogitsProcessor: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + class NoRepeatNGramLogitsProcessor: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + class PrefixConstrainedLogitsProcessor: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + class RepetitionPenaltyLogitsProcessor: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + class TemperatureLogitsWarper: def __init__(self, *args, **kwargs): diff --git a/src/transformers/utils/dummy_sentencepiece_and_speech_objects.py b/src/transformers/utils/dummy_sentencepiece_and_speech_objects.py index b030ce604a..42727619d9 100644 --- a/src/transformers/utils/dummy_sentencepiece_and_speech_objects.py +++ b/src/transformers/utils/dummy_sentencepiece_and_speech_objects.py @@ -5,3 +5,7 @@ from ..file_utils import requires_backends class Speech2TextProcessor: def __init__(self, *args, **kwargs): requires_backends(self, ["sentencepiece", "speech"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["sentencepiece", "speech"]) diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index 84b37d35df..b03bc23253 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -16,6 +16,10 @@ class CLIPProcessor: def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["vision"]) + class DeiTFeatureExtractor: def __init__(self, *args, **kwargs): diff --git a/utils/check_dummies.py b/utils/check_dummies.py index bd990abac0..0c98908968 100644 --- a/utils/check_dummies.py +++ b/utils/check_dummies.py @@ -115,6 +115,7 @@ def create_dummy_object(name, backend_name): "ForTokenClassification", "Model", "Tokenizer", + "Processor", ] if name.isupper(): return DUMMY_CONSTANT.format(name)