From 1498eb9888d55d76385b45e074f26703cc5049f3 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Tue, 22 Jun 2021 18:26:05 +0530 Subject: [PATCH] add FlaxAutoModelForImageClassification in main init (#12298) --- docs/source/model_doc/auto.rst | 7 +++++++ src/transformers/__init__.py | 4 ++++ src/transformers/models/auto/__init__.py | 2 ++ src/transformers/models/auto/modeling_flax_auto.py | 4 ++-- src/transformers/utils/dummy_flax_objects.py | 12 ++++++++++++ 5 files changed, 27 insertions(+), 2 deletions(-) diff --git a/docs/source/model_doc/auto.rst b/docs/source/model_doc/auto.rst index 69f67d7f56..7ccfbdf87d 100644 --- a/docs/source/model_doc/auto.rst +++ b/docs/source/model_doc/auto.rst @@ -266,3 +266,10 @@ FlaxAutoModelForNextSentencePrediction .. autoclass:: transformers.FlaxAutoModelForNextSentencePrediction :members: + + +FlaxAutoModelForImageClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxAutoModelForImageClassification + :members: diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index dad079d40e..0d70222780 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1509,6 +1509,7 @@ if is_flax_available(): _import_structure["models.auto"].extend( [ "FLAX_MODEL_FOR_CAUSAL_LM_MAPPING", + "FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", "FLAX_MODEL_FOR_MASKED_LM_MAPPING", "FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING", "FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", @@ -1520,6 +1521,7 @@ if is_flax_available(): "FLAX_MODEL_MAPPING", "FlaxAutoModel", "FlaxAutoModelForCausalLM", + "FlaxAutoModelForImageClassification", "FlaxAutoModelForMaskedLM", "FlaxAutoModelForMultipleChoice", "FlaxAutoModelForNextSentencePrediction", @@ -2848,6 +2850,7 @@ if TYPE_CHECKING: from .modeling_flax_utils import FlaxPreTrainedModel from .models.auto import ( FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, + FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, FLAX_MODEL_FOR_MASKED_LM_MAPPING, FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING, FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, @@ -2859,6 +2862,7 @@ if TYPE_CHECKING: FLAX_MODEL_MAPPING, FlaxAutoModel, FlaxAutoModelForCausalLM, + FlaxAutoModelForImageClassification, FlaxAutoModelForMaskedLM, FlaxAutoModelForMultipleChoice, FlaxAutoModelForNextSentencePrediction, diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index d483b271b8..f0e16ca27d 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -87,6 +87,7 @@ if is_tf_available(): if is_flax_available(): _import_structure["modeling_flax_auto"] = [ "FLAX_MODEL_FOR_CAUSAL_LM_MAPPING", + "FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", "FLAX_MODEL_FOR_MASKED_LM_MAPPING", "FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING", "FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", @@ -175,6 +176,7 @@ if TYPE_CHECKING: if is_flax_available(): from .modeling_flax_auto import ( FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, + FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, FLAX_MODEL_FOR_MASKED_LM_MAPPING, FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING, FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index be03814c3b..dd3d3cd809 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -115,7 +115,7 @@ FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict( ] ) -FLAX_MODEL_FOR_IMAGECLASSIFICATION_MAPPING = OrderedDict( +FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = OrderedDict( [ # Model for Image-classsification (ViTConfig, FlaxViTForImageClassification), @@ -188,7 +188,7 @@ FlaxAutoModel = auto_class_factory("FlaxAutoModel", FLAX_MODEL_MAPPING) FlaxAutoModelForImageClassification = auto_class_factory( "FlaxAutoModelForImageClassification", - FLAX_MODEL_FOR_IMAGECLASSIFICATION_MAPPING, + FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, head_doc="image classification modeling", ) diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index 7ad7ee76b6..0eea12143b 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -79,6 +79,9 @@ class FlaxPreTrainedModel: FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = None +FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = None + + FLAX_MODEL_FOR_MASKED_LM_MAPPING = None @@ -124,6 +127,15 @@ class FlaxAutoModelForCausalLM: requires_backends(cls, ["flax"]) +class FlaxAutoModelForImageClassification: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + class FlaxAutoModelForMaskedLM: def __init__(self, *args, **kwargs): requires_backends(self, ["flax"])