[AutoModel] Add AutoModelForTextEncoding (#24305)
* [AutoModel] Add AutoModelForTextEncoding
* add mt5
* add other models
* add to docs
* fix tf imports
* add tf to docs / init
* up
* fix inits
* add to dummy objects
This commit is contained in:
@@ -214,6 +214,14 @@ The following auto classes are available for the following natural language proc
|
||||
|
||||
[[autodoc]] FlaxAutoModelForQuestionAnswering
|
||||
|
||||
### AutoModelForTextEncoding
|
||||
|
||||
[[autodoc]] AutoModelForTextEncoding
|
||||
|
||||
### TFAutoModelForTextEncoding
|
||||
|
||||
[[autodoc]] TFAutoModelForTextEncoding
|
||||
|
||||
## Computer vision
|
||||
|
||||
The following auto classes are available for the following computer vision tasks.
|
||||
|
||||
@@ -1053,6 +1053,7 @@ else:
|
||||
"MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
|
||||
"MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
|
||||
"MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
|
||||
"MODEL_FOR_TEXT_ENCODING_MAPPING",
|
||||
"MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
||||
"MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING",
|
||||
"MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING",
|
||||
@@ -1087,6 +1088,7 @@ else:
|
||||
"AutoModelForSequenceClassification",
|
||||
"AutoModelForSpeechSeq2Seq",
|
||||
"AutoModelForTableQuestionAnswering",
|
||||
"AutoModelForTextEncoding",
|
||||
"AutoModelForTokenClassification",
|
||||
"AutoModelForUniversalSegmentation",
|
||||
"AutoModelForVideoClassification",
|
||||
@@ -2984,6 +2986,7 @@ else:
|
||||
"TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
|
||||
"TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
|
||||
"TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
|
||||
"TF_MODEL_FOR_TEXT_ENCODING_MAPPING",
|
||||
"TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
||||
"TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
|
||||
"TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING",
|
||||
@@ -3003,6 +3006,7 @@ else:
|
||||
"TFAutoModelForSequenceClassification",
|
||||
"TFAutoModelForSpeechSeq2Seq",
|
||||
"TFAutoModelForTableQuestionAnswering",
|
||||
"TFAutoModelForTextEncoding",
|
||||
"TFAutoModelForTokenClassification",
|
||||
"TFAutoModelForVision2Seq",
|
||||
"TFAutoModelForZeroShotImageClassification",
|
||||
@@ -4807,6 +4811,7 @@ if TYPE_CHECKING:
|
||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
|
||||
MODEL_FOR_TEXT_ENCODING_MAPPING,
|
||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING,
|
||||
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
|
||||
@@ -4841,6 +4846,7 @@ if TYPE_CHECKING:
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelForSpeechSeq2Seq,
|
||||
AutoModelForTableQuestionAnswering,
|
||||
AutoModelForTextEncoding,
|
||||
AutoModelForTokenClassification,
|
||||
AutoModelForUniversalSegmentation,
|
||||
AutoModelForVideoClassification,
|
||||
@@ -6374,6 +6380,7 @@ if TYPE_CHECKING:
|
||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
|
||||
TF_MODEL_FOR_TEXT_ENCODING_MAPPING,
|
||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
|
||||
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING,
|
||||
@@ -6393,6 +6400,7 @@ if TYPE_CHECKING:
|
||||
TFAutoModelForSequenceClassification,
|
||||
TFAutoModelForSpeechSeq2Seq,
|
||||
TFAutoModelForTableQuestionAnswering,
|
||||
TFAutoModelForTextEncoding,
|
||||
TFAutoModelForTokenClassification,
|
||||
TFAutoModelForVision2Seq,
|
||||
TFAutoModelForZeroShotImageClassification,
|
||||
|
||||
@@ -64,6 +64,7 @@ else:
|
||||
"MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
|
||||
"MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
|
||||
"MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
|
||||
"MODEL_FOR_TEXT_ENCODING_MAPPING",
|
||||
"MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
||||
"MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING",
|
||||
"MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING",
|
||||
@@ -85,6 +86,7 @@ else:
|
||||
"AutoModelForImageSegmentation",
|
||||
"AutoModelForInstanceSegmentation",
|
||||
"AutoModelForMaskGeneration",
|
||||
"AutoModelForTextEncoding",
|
||||
"AutoModelForMaskedImageModeling",
|
||||
"AutoModelForMaskedLM",
|
||||
"AutoModelForMultipleChoice",
|
||||
@@ -131,6 +133,7 @@ else:
|
||||
"TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
|
||||
"TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
|
||||
"TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
|
||||
"TF_MODEL_FOR_TEXT_ENCODING_MAPPING",
|
||||
"TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
||||
"TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
|
||||
"TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING",
|
||||
@@ -150,6 +153,7 @@ else:
|
||||
"TFAutoModelForSequenceClassification",
|
||||
"TFAutoModelForSpeechSeq2Seq",
|
||||
"TFAutoModelForTableQuestionAnswering",
|
||||
"TFAutoModelForTextEncoding",
|
||||
"TFAutoModelForTokenClassification",
|
||||
"TFAutoModelForVision2Seq",
|
||||
"TFAutoModelForZeroShotImageClassification",
|
||||
@@ -233,6 +237,7 @@ if TYPE_CHECKING:
|
||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
|
||||
MODEL_FOR_TEXT_ENCODING_MAPPING,
|
||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING,
|
||||
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
|
||||
@@ -267,6 +272,7 @@ if TYPE_CHECKING:
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelForSpeechSeq2Seq,
|
||||
AutoModelForTableQuestionAnswering,
|
||||
AutoModelForTextEncoding,
|
||||
AutoModelForTokenClassification,
|
||||
AutoModelForUniversalSegmentation,
|
||||
AutoModelForVideoClassification,
|
||||
@@ -300,6 +306,7 @@ if TYPE_CHECKING:
|
||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
|
||||
TF_MODEL_FOR_TEXT_ENCODING_MAPPING,
|
||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
|
||||
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING,
|
||||
@@ -319,6 +326,7 @@ if TYPE_CHECKING:
|
||||
TFAutoModelForSequenceClassification,
|
||||
TFAutoModelForSpeechSeq2Seq,
|
||||
TFAutoModelForTableQuestionAnswering,
|
||||
TFAutoModelForTextEncoding,
|
||||
TFAutoModelForTokenClassification,
|
||||
TFAutoModelForVision2Seq,
|
||||
TFAutoModelForZeroShotImageClassification,
|
||||
|
||||
@@ -1011,6 +1011,36 @@ MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
|
||||
]
|
||||
)
|
||||
|
||||
MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("albert", "AlbertModel"),
|
||||
("bert", "BertModel"),
|
||||
("big_bird", "BigBirdModel"),
|
||||
("data2vec-text", "Data2VecTextModel"),
|
||||
("deberta", "DebertaModel"),
|
||||
("deberta-v2", "DebertaV2Model"),
|
||||
("distilbert", "DistilBertModel"),
|
||||
("electra", "ElectraModel"),
|
||||
("flaubert", "FlaubertModel"),
|
||||
("ibert", "IBertModel"),
|
||||
("longformer", "LongformerModel"),
|
||||
("mobilebert", "MobileBertModel"),
|
||||
("mt5", "MT5EncoderModel"),
|
||||
("nystromformer", "NystromformerModel"),
|
||||
("reformer", "ReformerModel"),
|
||||
("rembert", "RemBertModel"),
|
||||
("roberta", "RobertaModel"),
|
||||
("roberta-prelayernorm", "RobertaPreLayerNormModel"),
|
||||
("roc_bert", "RoCBertModel"),
|
||||
("roformer", "RoFormerModel"),
|
||||
("squeezebert", "SqueezeBertModel"),
|
||||
("t5", "T5EncoderModel"),
|
||||
("xlm", "XLMModel"),
|
||||
("xlm-roberta", "XLMRobertaModel"),
|
||||
("xlm-roberta-xl", "XLMRobertaXLModel"),
|
||||
]
|
||||
)
|
||||
|
||||
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
|
||||
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
|
||||
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES)
|
||||
@@ -1088,11 +1118,17 @@ MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_BA
|
||||
|
||||
MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASK_GENERATION_MAPPING_NAMES)
|
||||
|
||||
MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES)
|
||||
|
||||
|
||||
class AutoModelForMaskGeneration(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING
|
||||
|
||||
|
||||
class AutoModelForTextEncoding(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_TEXT_ENCODING_MAPPING
|
||||
|
||||
|
||||
class AutoModel(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_MAPPING
|
||||
|
||||
|
||||
@@ -437,6 +437,28 @@ TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
|
||||
("sam", "TFSamModel"),
|
||||
]
|
||||
)
|
||||
TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("albert", "TFAlbertModel"),
|
||||
("bert", "TFBertModel"),
|
||||
("convbert", "TFConvBertModel"),
|
||||
("deberta", "TFDebertaModel"),
|
||||
("deberta-v2", "TFDebertaV2Model"),
|
||||
("distilbert", "TFDistilBertModel"),
|
||||
("electra", "TFElectraModel"),
|
||||
("flaubert", "TFFlaubertModel"),
|
||||
("longformer", "TFLongformerModel"),
|
||||
("mobilebert", "TFMobileBertModel"),
|
||||
("mt5", "TFMT5EncoderModel"),
|
||||
("rembert", "TFRemBertModel"),
|
||||
("roberta", "TFRobertaModel"),
|
||||
("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
|
||||
("roformer", "TFRoFormerModel"),
|
||||
("t5", "TFT5EncoderModel"),
|
||||
("xlm", "TFXLMModel"),
|
||||
("xlm-roberta", "TFXLMRobertaModel"),
|
||||
]
|
||||
)
|
||||
|
||||
TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES)
|
||||
TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
|
||||
@@ -491,11 +513,17 @@ TF_MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(
|
||||
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES
|
||||
)
|
||||
|
||||
TF_MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES)
|
||||
|
||||
|
||||
class TFAutoModelForMaskGeneration(_BaseAutoModelClass):
|
||||
_model_mapping = TF_MODEL_FOR_MASK_GENERATION_MAPPING
|
||||
|
||||
|
||||
class TFAutoModelForTextEncoding(_BaseAutoModelClass):
|
||||
_model_mapping = TF_MODEL_FOR_TEXT_ENCODING_MAPPING
|
||||
|
||||
|
||||
class TFAutoModel(_BaseAutoModelClass):
|
||||
_model_mapping = TF_MODEL_MAPPING
|
||||
|
||||
|
||||
@@ -524,6 +524,9 @@ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = None
|
||||
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = None
|
||||
|
||||
|
||||
MODEL_FOR_TEXT_ENCODING_MAPPING = None
|
||||
|
||||
|
||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
|
||||
|
||||
|
||||
@@ -726,6 +729,13 @@ class AutoModelForTableQuestionAnswering(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class AutoModelForTextEncoding(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class AutoModelForTokenClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -264,6 +264,9 @@ TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = None
|
||||
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = None
|
||||
|
||||
|
||||
TF_MODEL_FOR_TEXT_ENCODING_MAPPING = None
|
||||
|
||||
|
||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
|
||||
|
||||
|
||||
@@ -377,6 +380,13 @@ class TFAutoModelForTableQuestionAnswering(metaclass=DummyObject):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFAutoModelForTextEncoding(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFAutoModelForTokenClassification(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user