From f74655cd9b2e316af9d862968bc59c15d6849cad Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Fri, 18 Jun 2021 13:20:09 +0530 Subject: [PATCH] [Flax] FlaxAutoModelForSeq2SeqLM (#12228) * add FlaxAutoModelForSeq2SeqLM --- docs/source/model_doc/auto.rst | 7 +++++++ src/transformers/__init__.py | 4 ++++ src/transformers/models/auto/__init__.py | 4 ++++ src/transformers/models/auto/modeling_flax_auto.py | 14 ++++++++++++++ src/transformers/utils/dummy_flax_objects.py | 12 ++++++++++++ 5 files changed, 41 insertions(+) diff --git a/docs/source/model_doc/auto.rst b/docs/source/model_doc/auto.rst index 7b8ce142e0..69f67d7f56 100644 --- a/docs/source/model_doc/auto.rst +++ b/docs/source/model_doc/auto.rst @@ -226,6 +226,13 @@ FlaxAutoModelForMaskedLM :members: +FlaxAutoModelForSeq2SeqLM +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxAutoModelForSeq2SeqLM + :members: + + FlaxAutoModelForSequenceClassification ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index d8a7fc003f..dad079d40e 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1514,6 +1514,7 @@ if is_flax_available(): "FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", "FLAX_MODEL_FOR_PRETRAINING_MAPPING", "FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING", + "FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING", "FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", "FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", "FLAX_MODEL_MAPPING", @@ -1524,6 +1525,7 @@ if is_flax_available(): "FlaxAutoModelForNextSentencePrediction", "FlaxAutoModelForPreTraining", "FlaxAutoModelForQuestionAnswering", + "FlaxAutoModelForSeq2SeqLM", "FlaxAutoModelForSequenceClassification", "FlaxAutoModelForTokenClassification", ] @@ -2851,6 +2853,7 @@ if TYPE_CHECKING: FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, FLAX_MODEL_FOR_PRETRAINING_MAPPING, FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, + FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, FLAX_MODEL_MAPPING, @@ -2861,6 +2864,7 @@ if TYPE_CHECKING: FlaxAutoModelForNextSentencePrediction, FlaxAutoModelForPreTraining, FlaxAutoModelForQuestionAnswering, + FlaxAutoModelForSeq2SeqLM, FlaxAutoModelForSequenceClassification, FlaxAutoModelForTokenClassification, ) diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index 2123889478..d483b271b8 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -92,6 +92,7 @@ if is_flax_available(): "FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", "FLAX_MODEL_FOR_PRETRAINING_MAPPING", "FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING", + "FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING", "FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", "FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", "FLAX_MODEL_MAPPING", @@ -103,6 +104,7 @@ if is_flax_available(): "FlaxAutoModelForNextSentencePrediction", "FlaxAutoModelForPreTraining", "FlaxAutoModelForQuestionAnswering", + "FlaxAutoModelForSeq2SeqLM", "FlaxAutoModelForSequenceClassification", "FlaxAutoModelForTokenClassification", ] @@ -178,6 +180,7 @@ if TYPE_CHECKING: FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, FLAX_MODEL_FOR_PRETRAINING_MAPPING, FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, + FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, FLAX_MODEL_MAPPING, @@ -189,6 +192,7 @@ if TYPE_CHECKING: FlaxAutoModelForNextSentencePrediction, FlaxAutoModelForPreTraining, FlaxAutoModelForQuestionAnswering, + FlaxAutoModelForSeq2SeqLM, FlaxAutoModelForSequenceClassification, FlaxAutoModelForTokenClassification, ) diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index ff59d35c62..be03814c3b 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -129,6 +129,13 @@ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict( ] ) +FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict( + [ + # Model for Seq2Seq Causal LM mapping + (BartConfig, FlaxBartForConditionalGeneration) + ] +) + FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict( [ # Model for Sequence Classification mapping @@ -197,6 +204,13 @@ FlaxAutoModelForMaskedLM = auto_class_factory( "FlaxAutoModelForMaskedLM", FLAX_MODEL_FOR_MASKED_LM_MAPPING, head_doc="masked language modeling" ) + +FlaxAutoModelForSeq2SeqLM = auto_class_factory( + "FlaxAutoModelForSeq2SeqLM", + FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, + head_doc="sequence-to-sequence language modeling", +) + FlaxAutoModelForSequenceClassification = auto_class_factory( "FlaxAutoModelForSequenceClassification", FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index 7bae4a9a76..7ad7ee76b6 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -94,6 +94,9 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING = None FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = None +FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = None + + FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = None @@ -166,6 +169,15 @@ class FlaxAutoModelForQuestionAnswering: requires_backends(cls, ["flax"]) +class FlaxAutoModelForSeq2SeqLM: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + class FlaxAutoModelForSequenceClassification: def __init__(self, *args, **kwargs): requires_backends(self, ["flax"])