[Flax] FlaxAutoModelForSeq2SeqLM (#12228)

* add FlaxAutoModelForSeq2SeqLM
This commit is contained in:
Suraj Patil
2021-06-18 13:20:09 +05:30
committed by GitHub
parent e43e11260f
commit f74655cd9b
5 changed files with 41 additions and 0 deletions

View File

@@ -226,6 +226,13 @@ FlaxAutoModelForMaskedLM
:members: :members:
FlaxAutoModelForSeq2SeqLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxAutoModelForSeq2SeqLM
:members:
FlaxAutoModelForSequenceClassification FlaxAutoModelForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@@ -1514,6 +1514,7 @@ if is_flax_available():
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", "FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
"FLAX_MODEL_FOR_PRETRAINING_MAPPING", "FLAX_MODEL_FOR_PRETRAINING_MAPPING",
"FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING", "FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
"FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
"FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", "FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", "FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"FLAX_MODEL_MAPPING", "FLAX_MODEL_MAPPING",
@@ -1524,6 +1525,7 @@ if is_flax_available():
"FlaxAutoModelForNextSentencePrediction", "FlaxAutoModelForNextSentencePrediction",
"FlaxAutoModelForPreTraining", "FlaxAutoModelForPreTraining",
"FlaxAutoModelForQuestionAnswering", "FlaxAutoModelForQuestionAnswering",
"FlaxAutoModelForSeq2SeqLM",
"FlaxAutoModelForSequenceClassification", "FlaxAutoModelForSequenceClassification",
"FlaxAutoModelForTokenClassification", "FlaxAutoModelForTokenClassification",
] ]
@@ -2851,6 +2853,7 @@ if TYPE_CHECKING:
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
FLAX_MODEL_FOR_PRETRAINING_MAPPING, FLAX_MODEL_FOR_PRETRAINING_MAPPING,
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
FLAX_MODEL_MAPPING, FLAX_MODEL_MAPPING,
@@ -2861,6 +2864,7 @@ if TYPE_CHECKING:
FlaxAutoModelForNextSentencePrediction, FlaxAutoModelForNextSentencePrediction,
FlaxAutoModelForPreTraining, FlaxAutoModelForPreTraining,
FlaxAutoModelForQuestionAnswering, FlaxAutoModelForQuestionAnswering,
FlaxAutoModelForSeq2SeqLM,
FlaxAutoModelForSequenceClassification, FlaxAutoModelForSequenceClassification,
FlaxAutoModelForTokenClassification, FlaxAutoModelForTokenClassification,
) )

View File

@@ -92,6 +92,7 @@ if is_flax_available():
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", "FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
"FLAX_MODEL_FOR_PRETRAINING_MAPPING", "FLAX_MODEL_FOR_PRETRAINING_MAPPING",
"FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING", "FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
"FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
"FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", "FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", "FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"FLAX_MODEL_MAPPING", "FLAX_MODEL_MAPPING",
@@ -103,6 +104,7 @@ if is_flax_available():
"FlaxAutoModelForNextSentencePrediction", "FlaxAutoModelForNextSentencePrediction",
"FlaxAutoModelForPreTraining", "FlaxAutoModelForPreTraining",
"FlaxAutoModelForQuestionAnswering", "FlaxAutoModelForQuestionAnswering",
"FlaxAutoModelForSeq2SeqLM",
"FlaxAutoModelForSequenceClassification", "FlaxAutoModelForSequenceClassification",
"FlaxAutoModelForTokenClassification", "FlaxAutoModelForTokenClassification",
] ]
@@ -178,6 +180,7 @@ if TYPE_CHECKING:
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
FLAX_MODEL_FOR_PRETRAINING_MAPPING, FLAX_MODEL_FOR_PRETRAINING_MAPPING,
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
FLAX_MODEL_MAPPING, FLAX_MODEL_MAPPING,
@@ -189,6 +192,7 @@ if TYPE_CHECKING:
FlaxAutoModelForNextSentencePrediction, FlaxAutoModelForNextSentencePrediction,
FlaxAutoModelForPreTraining, FlaxAutoModelForPreTraining,
FlaxAutoModelForQuestionAnswering, FlaxAutoModelForQuestionAnswering,
FlaxAutoModelForSeq2SeqLM,
FlaxAutoModelForSequenceClassification, FlaxAutoModelForSequenceClassification,
FlaxAutoModelForTokenClassification, FlaxAutoModelForTokenClassification,
) )

View File

@@ -129,6 +129,13 @@ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
] ]
) )
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
[
# Model for Seq2Seq Causal LM mapping
(BartConfig, FlaxBartForConditionalGeneration)
]
)
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict( FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
[ [
# Model for Sequence Classification mapping # Model for Sequence Classification mapping
@@ -197,6 +204,13 @@ FlaxAutoModelForMaskedLM = auto_class_factory(
"FlaxAutoModelForMaskedLM", FLAX_MODEL_FOR_MASKED_LM_MAPPING, head_doc="masked language modeling" "FlaxAutoModelForMaskedLM", FLAX_MODEL_FOR_MASKED_LM_MAPPING, head_doc="masked language modeling"
) )
FlaxAutoModelForSeq2SeqLM = auto_class_factory(
"FlaxAutoModelForSeq2SeqLM",
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
head_doc="sequence-to-sequence language modeling",
)
FlaxAutoModelForSequenceClassification = auto_class_factory( FlaxAutoModelForSequenceClassification = auto_class_factory(
"FlaxAutoModelForSequenceClassification", "FlaxAutoModelForSequenceClassification",
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,

View File

@@ -94,6 +94,9 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING = None
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = None FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = None
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = None
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = None FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = None
@@ -166,6 +169,15 @@ class FlaxAutoModelForQuestionAnswering:
requires_backends(cls, ["flax"]) requires_backends(cls, ["flax"])
class FlaxAutoModelForSeq2SeqLM:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxAutoModelForSequenceClassification: class FlaxAutoModelForSequenceClassification:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"]) requires_backends(self, ["flax"])