add FlaxAutoModelForImageClassification in main init (#12298)
This commit is contained in:
@@ -266,3 +266,10 @@ FlaxAutoModelForNextSentencePrediction
|
|||||||
|
|
||||||
.. autoclass:: transformers.FlaxAutoModelForNextSentencePrediction
|
.. autoclass:: transformers.FlaxAutoModelForNextSentencePrediction
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
FlaxAutoModelForImageClassification
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxAutoModelForImageClassification
|
||||||
|
:members:
|
||||||
|
|||||||
@@ -1509,6 +1509,7 @@ if is_flax_available():
|
|||||||
_import_structure["models.auto"].extend(
|
_import_structure["models.auto"].extend(
|
||||||
[
|
[
|
||||||
"FLAX_MODEL_FOR_CAUSAL_LM_MAPPING",
|
"FLAX_MODEL_FOR_CAUSAL_LM_MAPPING",
|
||||||
|
"FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
|
||||||
"FLAX_MODEL_FOR_MASKED_LM_MAPPING",
|
"FLAX_MODEL_FOR_MASKED_LM_MAPPING",
|
||||||
"FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
|
"FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
|
||||||
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
|
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
|
||||||
@@ -1520,6 +1521,7 @@ if is_flax_available():
|
|||||||
"FLAX_MODEL_MAPPING",
|
"FLAX_MODEL_MAPPING",
|
||||||
"FlaxAutoModel",
|
"FlaxAutoModel",
|
||||||
"FlaxAutoModelForCausalLM",
|
"FlaxAutoModelForCausalLM",
|
||||||
|
"FlaxAutoModelForImageClassification",
|
||||||
"FlaxAutoModelForMaskedLM",
|
"FlaxAutoModelForMaskedLM",
|
||||||
"FlaxAutoModelForMultipleChoice",
|
"FlaxAutoModelForMultipleChoice",
|
||||||
"FlaxAutoModelForNextSentencePrediction",
|
"FlaxAutoModelForNextSentencePrediction",
|
||||||
@@ -2848,6 +2850,7 @@ if TYPE_CHECKING:
|
|||||||
from .modeling_flax_utils import FlaxPreTrainedModel
|
from .modeling_flax_utils import FlaxPreTrainedModel
|
||||||
from .models.auto import (
|
from .models.auto import (
|
||||||
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
|
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||||
|
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||||
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
|
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
|
||||||
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||||
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||||
@@ -2859,6 +2862,7 @@ if TYPE_CHECKING:
|
|||||||
FLAX_MODEL_MAPPING,
|
FLAX_MODEL_MAPPING,
|
||||||
FlaxAutoModel,
|
FlaxAutoModel,
|
||||||
FlaxAutoModelForCausalLM,
|
FlaxAutoModelForCausalLM,
|
||||||
|
FlaxAutoModelForImageClassification,
|
||||||
FlaxAutoModelForMaskedLM,
|
FlaxAutoModelForMaskedLM,
|
||||||
FlaxAutoModelForMultipleChoice,
|
FlaxAutoModelForMultipleChoice,
|
||||||
FlaxAutoModelForNextSentencePrediction,
|
FlaxAutoModelForNextSentencePrediction,
|
||||||
|
|||||||
@@ -87,6 +87,7 @@ if is_tf_available():
|
|||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
_import_structure["modeling_flax_auto"] = [
|
_import_structure["modeling_flax_auto"] = [
|
||||||
"FLAX_MODEL_FOR_CAUSAL_LM_MAPPING",
|
"FLAX_MODEL_FOR_CAUSAL_LM_MAPPING",
|
||||||
|
"FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
|
||||||
"FLAX_MODEL_FOR_MASKED_LM_MAPPING",
|
"FLAX_MODEL_FOR_MASKED_LM_MAPPING",
|
||||||
"FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
|
"FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
|
||||||
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
|
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
|
||||||
@@ -175,6 +176,7 @@ if TYPE_CHECKING:
|
|||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
from .modeling_flax_auto import (
|
from .modeling_flax_auto import (
|
||||||
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
|
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||||
|
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||||
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
|
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
|
||||||
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||||
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||||
|
|||||||
@@ -115,7 +115,7 @@ FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
FLAX_MODEL_FOR_IMAGECLASSIFICATION_MAPPING = OrderedDict(
|
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = OrderedDict(
|
||||||
[
|
[
|
||||||
# Model for Image-classsification
|
# Model for Image-classsification
|
||||||
(ViTConfig, FlaxViTForImageClassification),
|
(ViTConfig, FlaxViTForImageClassification),
|
||||||
@@ -188,7 +188,7 @@ FlaxAutoModel = auto_class_factory("FlaxAutoModel", FLAX_MODEL_MAPPING)
|
|||||||
|
|
||||||
FlaxAutoModelForImageClassification = auto_class_factory(
|
FlaxAutoModelForImageClassification = auto_class_factory(
|
||||||
"FlaxAutoModelForImageClassification",
|
"FlaxAutoModelForImageClassification",
|
||||||
FLAX_MODEL_FOR_IMAGECLASSIFICATION_MAPPING,
|
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||||
head_doc="image classification modeling",
|
head_doc="image classification modeling",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -79,6 +79,9 @@ class FlaxPreTrainedModel:
|
|||||||
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = None
|
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
|
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
FLAX_MODEL_FOR_MASKED_LM_MAPPING = None
|
FLAX_MODEL_FOR_MASKED_LM_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
@@ -124,6 +127,15 @@ class FlaxAutoModelForCausalLM:
|
|||||||
requires_backends(cls, ["flax"])
|
requires_backends(cls, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxAutoModelForImageClassification:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
class FlaxAutoModelForMaskedLM:
|
class FlaxAutoModelForMaskedLM:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["flax"])
|
requires_backends(self, ["flax"])
|
||||||
|
|||||||
Reference in New Issue
Block a user