add FlaxAutoModelForImageClassification in main init (#12298)

This commit is contained in:
Suraj Patil
2021-06-22 18:26:05 +05:30
committed by GitHub
parent 2affeb2905
commit 1498eb9888
5 changed files with 27 additions and 2 deletions

View File

@@ -266,3 +266,10 @@ FlaxAutoModelForNextSentencePrediction
.. autoclass:: transformers.FlaxAutoModelForNextSentencePrediction
:members:
FlaxAutoModelForImageClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxAutoModelForImageClassification
:members:

View File

@@ -1509,6 +1509,7 @@ if is_flax_available():
_import_structure["models.auto"].extend(
[
"FLAX_MODEL_FOR_CAUSAL_LM_MAPPING",
"FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_MASKED_LM_MAPPING",
"FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
@@ -1520,6 +1521,7 @@ if is_flax_available():
"FLAX_MODEL_MAPPING",
"FlaxAutoModel",
"FlaxAutoModelForCausalLM",
"FlaxAutoModelForImageClassification",
"FlaxAutoModelForMaskedLM",
"FlaxAutoModelForMultipleChoice",
"FlaxAutoModelForNextSentencePrediction",
@@ -2848,6 +2850,7 @@ if TYPE_CHECKING:
from .modeling_flax_utils import FlaxPreTrainedModel
from .models.auto import (
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
@@ -2859,6 +2862,7 @@ if TYPE_CHECKING:
FLAX_MODEL_MAPPING,
FlaxAutoModel,
FlaxAutoModelForCausalLM,
FlaxAutoModelForImageClassification,
FlaxAutoModelForMaskedLM,
FlaxAutoModelForMultipleChoice,
FlaxAutoModelForNextSentencePrediction,

View File

@@ -87,6 +87,7 @@ if is_tf_available():
if is_flax_available():
_import_structure["modeling_flax_auto"] = [
"FLAX_MODEL_FOR_CAUSAL_LM_MAPPING",
"FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_MASKED_LM_MAPPING",
"FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
@@ -175,6 +176,7 @@ if TYPE_CHECKING:
if is_flax_available():
from .modeling_flax_auto import (
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,

View File

@@ -115,7 +115,7 @@ FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
]
)
FLAX_MODEL_FOR_IMAGECLASSIFICATION_MAPPING = OrderedDict(
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = OrderedDict(
[
# Model for Image-classsification
(ViTConfig, FlaxViTForImageClassification),
@@ -188,7 +188,7 @@ FlaxAutoModel = auto_class_factory("FlaxAutoModel", FLAX_MODEL_MAPPING)
FlaxAutoModelForImageClassification = auto_class_factory(
"FlaxAutoModelForImageClassification",
FLAX_MODEL_FOR_IMAGECLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
head_doc="image classification modeling",
)

View File

@@ -79,6 +79,9 @@ class FlaxPreTrainedModel:
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = None
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = None
FLAX_MODEL_FOR_MASKED_LM_MAPPING = None
@@ -124,6 +127,15 @@ class FlaxAutoModelForCausalLM:
requires_backends(cls, ["flax"])
class FlaxAutoModelForImageClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxAutoModelForMaskedLM:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])