[Flax] FlaxAutoModelForSeq2SeqLM (#12228)
* add FlaxAutoModelForSeq2SeqLM
This commit is contained in:
@@ -226,6 +226,13 @@ FlaxAutoModelForMaskedLM
|
|||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
FlaxAutoModelForSeq2SeqLM
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxAutoModelForSeq2SeqLM
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
FlaxAutoModelForSequenceClassification
|
FlaxAutoModelForSequenceClassification
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|||||||
@@ -1514,6 +1514,7 @@ if is_flax_available():
|
|||||||
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
|
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
|
||||||
"FLAX_MODEL_FOR_PRETRAINING_MAPPING",
|
"FLAX_MODEL_FOR_PRETRAINING_MAPPING",
|
||||||
"FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
|
"FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
|
||||||
|
"FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
|
||||||
"FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
|
"FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
|
||||||
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
||||||
"FLAX_MODEL_MAPPING",
|
"FLAX_MODEL_MAPPING",
|
||||||
@@ -1524,6 +1525,7 @@ if is_flax_available():
|
|||||||
"FlaxAutoModelForNextSentencePrediction",
|
"FlaxAutoModelForNextSentencePrediction",
|
||||||
"FlaxAutoModelForPreTraining",
|
"FlaxAutoModelForPreTraining",
|
||||||
"FlaxAutoModelForQuestionAnswering",
|
"FlaxAutoModelForQuestionAnswering",
|
||||||
|
"FlaxAutoModelForSeq2SeqLM",
|
||||||
"FlaxAutoModelForSequenceClassification",
|
"FlaxAutoModelForSequenceClassification",
|
||||||
"FlaxAutoModelForTokenClassification",
|
"FlaxAutoModelForTokenClassification",
|
||||||
]
|
]
|
||||||
@@ -2851,6 +2853,7 @@ if TYPE_CHECKING:
|
|||||||
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||||
FLAX_MODEL_FOR_PRETRAINING_MAPPING,
|
FLAX_MODEL_FOR_PRETRAINING_MAPPING,
|
||||||
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||||
|
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||||
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
FLAX_MODEL_MAPPING,
|
FLAX_MODEL_MAPPING,
|
||||||
@@ -2861,6 +2864,7 @@ if TYPE_CHECKING:
|
|||||||
FlaxAutoModelForNextSentencePrediction,
|
FlaxAutoModelForNextSentencePrediction,
|
||||||
FlaxAutoModelForPreTraining,
|
FlaxAutoModelForPreTraining,
|
||||||
FlaxAutoModelForQuestionAnswering,
|
FlaxAutoModelForQuestionAnswering,
|
||||||
|
FlaxAutoModelForSeq2SeqLM,
|
||||||
FlaxAutoModelForSequenceClassification,
|
FlaxAutoModelForSequenceClassification,
|
||||||
FlaxAutoModelForTokenClassification,
|
FlaxAutoModelForTokenClassification,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -92,6 +92,7 @@ if is_flax_available():
|
|||||||
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
|
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
|
||||||
"FLAX_MODEL_FOR_PRETRAINING_MAPPING",
|
"FLAX_MODEL_FOR_PRETRAINING_MAPPING",
|
||||||
"FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
|
"FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
|
||||||
|
"FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
|
||||||
"FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
|
"FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
|
||||||
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
||||||
"FLAX_MODEL_MAPPING",
|
"FLAX_MODEL_MAPPING",
|
||||||
@@ -103,6 +104,7 @@ if is_flax_available():
|
|||||||
"FlaxAutoModelForNextSentencePrediction",
|
"FlaxAutoModelForNextSentencePrediction",
|
||||||
"FlaxAutoModelForPreTraining",
|
"FlaxAutoModelForPreTraining",
|
||||||
"FlaxAutoModelForQuestionAnswering",
|
"FlaxAutoModelForQuestionAnswering",
|
||||||
|
"FlaxAutoModelForSeq2SeqLM",
|
||||||
"FlaxAutoModelForSequenceClassification",
|
"FlaxAutoModelForSequenceClassification",
|
||||||
"FlaxAutoModelForTokenClassification",
|
"FlaxAutoModelForTokenClassification",
|
||||||
]
|
]
|
||||||
@@ -178,6 +180,7 @@ if TYPE_CHECKING:
|
|||||||
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||||
FLAX_MODEL_FOR_PRETRAINING_MAPPING,
|
FLAX_MODEL_FOR_PRETRAINING_MAPPING,
|
||||||
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||||
|
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||||
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
FLAX_MODEL_MAPPING,
|
FLAX_MODEL_MAPPING,
|
||||||
@@ -189,6 +192,7 @@ if TYPE_CHECKING:
|
|||||||
FlaxAutoModelForNextSentencePrediction,
|
FlaxAutoModelForNextSentencePrediction,
|
||||||
FlaxAutoModelForPreTraining,
|
FlaxAutoModelForPreTraining,
|
||||||
FlaxAutoModelForQuestionAnswering,
|
FlaxAutoModelForQuestionAnswering,
|
||||||
|
FlaxAutoModelForSeq2SeqLM,
|
||||||
FlaxAutoModelForSequenceClassification,
|
FlaxAutoModelForSequenceClassification,
|
||||||
FlaxAutoModelForTokenClassification,
|
FlaxAutoModelForTokenClassification,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -129,6 +129,13 @@ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
|
||||||
|
[
|
||||||
|
# Model for Seq2Seq Causal LM mapping
|
||||||
|
(BartConfig, FlaxBartForConditionalGeneration)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
||||||
[
|
[
|
||||||
# Model for Sequence Classification mapping
|
# Model for Sequence Classification mapping
|
||||||
@@ -197,6 +204,13 @@ FlaxAutoModelForMaskedLM = auto_class_factory(
|
|||||||
"FlaxAutoModelForMaskedLM", FLAX_MODEL_FOR_MASKED_LM_MAPPING, head_doc="masked language modeling"
|
"FlaxAutoModelForMaskedLM", FLAX_MODEL_FOR_MASKED_LM_MAPPING, head_doc="masked language modeling"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
FlaxAutoModelForSeq2SeqLM = auto_class_factory(
|
||||||
|
"FlaxAutoModelForSeq2SeqLM",
|
||||||
|
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||||
|
head_doc="sequence-to-sequence language modeling",
|
||||||
|
)
|
||||||
|
|
||||||
FlaxAutoModelForSequenceClassification = auto_class_factory(
|
FlaxAutoModelForSequenceClassification = auto_class_factory(
|
||||||
"FlaxAutoModelForSequenceClassification",
|
"FlaxAutoModelForSequenceClassification",
|
||||||
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
|
|||||||
@@ -94,6 +94,9 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING = None
|
|||||||
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = None
|
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
|
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = None
|
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
@@ -166,6 +169,15 @@ class FlaxAutoModelForQuestionAnswering:
|
|||||||
requires_backends(cls, ["flax"])
|
requires_backends(cls, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxAutoModelForSeq2SeqLM:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
class FlaxAutoModelForSequenceClassification:
|
class FlaxAutoModelForSequenceClassification:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["flax"])
|
requires_backends(self, ["flax"])
|
||||||
|
|||||||
Reference in New Issue
Block a user