diff --git a/docs/source/en/model_doc/deberta-v2.md b/docs/source/en/model_doc/deberta-v2.md index 38f575877e..8dec57a171 100644 --- a/docs/source/en/model_doc/deberta-v2.md +++ b/docs/source/en/model_doc/deberta-v2.md @@ -152,3 +152,8 @@ contributed by [kamalkraj](https://huggingface.co/kamalkraj). The original code [[autodoc]] TFDebertaV2ForQuestionAnswering - call + +## TFDebertaV2ForMultipleChoice + +[[autodoc]] TFDebertaV2ForMultipleChoice + - call diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 5a8e9f93ed..fbb12fc45f 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3360,6 +3360,7 @@ else: [ "TF_DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST", "TFDebertaV2ForMaskedLM", + "TFDebertaV2ForMultipleChoice", "TFDebertaV2ForQuestionAnswering", "TFDebertaV2ForSequenceClassification", "TFDebertaV2ForTokenClassification", @@ -6969,6 +6970,7 @@ if TYPE_CHECKING: from .models.deberta_v2 import ( TF_DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST, TFDebertaV2ForMaskedLM, + TFDebertaV2ForMultipleChoice, TFDebertaV2ForQuestionAnswering, TFDebertaV2ForSequenceClassification, TFDebertaV2ForTokenClassification, diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index ecf9b06da5..d2068a4511 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -409,6 +409,7 @@ TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict( ("bert", "TFBertForMultipleChoice"), ("camembert", "TFCamembertForMultipleChoice"), ("convbert", "TFConvBertForMultipleChoice"), + ("deberta-v2", "TFDebertaV2ForMultipleChoice"), ("distilbert", "TFDistilBertForMultipleChoice"), ("electra", "TFElectraForMultipleChoice"), ("flaubert", "TFFlaubertForMultipleChoice"), diff --git a/src/transformers/models/deberta_v2/__init__.py b/src/transformers/models/deberta_v2/__init__.py index 2c5bf57572..fb1b20a331 100644 --- a/src/transformers/models/deberta_v2/__init__.py +++ b/src/transformers/models/deberta_v2/__init__.py @@ -46,6 +46,7 @@ else: "TF_DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST", "TFDebertaV2ForMaskedLM", "TFDebertaV2ForQuestionAnswering", + "TFDebertaV2ForMultipleChoice", "TFDebertaV2ForSequenceClassification", "TFDebertaV2ForTokenClassification", "TFDebertaV2Model", @@ -95,6 +96,7 @@ if TYPE_CHECKING: from .modeling_tf_deberta_v2 import ( TF_DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST, TFDebertaV2ForMaskedLM, + TFDebertaV2ForMultipleChoice, TFDebertaV2ForQuestionAnswering, TFDebertaV2ForSequenceClassification, TFDebertaV2ForTokenClassification, diff --git a/src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py index e53d86dcb8..fa2cf1df74 100644 --- a/src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py @@ -14,7 +14,6 @@ # limitations under the License. """ TF 2.0 DeBERTa-v2 model.""" - from __future__ import annotations from typing import Dict, Optional, Tuple, Union @@ -26,6 +25,7 @@ from ...activations_tf import get_tf_activation from ...modeling_tf_outputs import ( TFBaseModelOutput, TFMaskedLMOutput, + TFMultipleChoiceModelOutput, TFQuestionAnsweringModelOutput, TFSequenceClassifierOutput, TFTokenClassifierOutput, @@ -33,6 +33,7 @@ from ...modeling_tf_outputs import ( from ...modeling_tf_utils import ( TFMaskedLanguageModelingLoss, TFModelInputType, + TFMultipleChoiceLoss, TFPreTrainedModel, TFQuestionAnsweringLoss, TFSequenceClassificationLoss, @@ -47,7 +48,6 @@ from .configuration_deberta_v2 import DebertaV2Config logger = logging.get_logger(__name__) - _CONFIG_FOR_DOC = "DebertaV2Config" _CHECKPOINT_FOR_DOC = "kamalkraj/deberta-v2-xlarge" @@ -1529,3 +1529,102 @@ class TFDebertaV2ForQuestionAnswering(TFDebertaV2PreTrainedModel, TFQuestionAnsw hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +@add_start_docstrings( + """ + DeBERTa Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + DEBERTA_START_DOCSTRING, +) +class TFDebertaV2ForMultipleChoice(TFDebertaV2PreTrainedModel, TFMultipleChoiceLoss): + # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + # _keys_to_ignore_on_load_unexpected = [r"mlm___cls", r"nsp___cls", r"cls.predictions", r"cls.seq_relationship"] + # _keys_to_ignore_on_load_missing = [r"dropout"] + + def __init__(self, config: DebertaV2Config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.deberta = TFDebertaV2MainLayer(config, name="deberta") + self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.pooler = TFDebertaV2ContextPooler(config, name="pooler") + self.classifier = tf.keras.layers.Dense( + units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier" + ) + + @unpack_inputs + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFMultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[TFMultipleChoiceModelOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above) + """ + if input_ids is not None: + num_choices = shape_list(input_ids)[1] + seq_length = shape_list(input_ids)[2] + else: + num_choices = shape_list(inputs_embeds)[1] + seq_length = shape_list(inputs_embeds)[2] + + flat_input_ids = tf.reshape(tensor=input_ids, shape=(-1, seq_length)) if input_ids is not None else None + flat_attention_mask = ( + tf.reshape(tensor=attention_mask, shape=(-1, seq_length)) if attention_mask is not None else None + ) + flat_token_type_ids = ( + tf.reshape(tensor=token_type_ids, shape=(-1, seq_length)) if token_type_ids is not None else None + ) + flat_position_ids = ( + tf.reshape(tensor=position_ids, shape=(-1, seq_length)) if position_ids is not None else None + ) + flat_inputs_embeds = ( + tf.reshape(tensor=inputs_embeds, shape=(-1, seq_length, shape_list(inputs_embeds)[3])) + if inputs_embeds is not None + else None + ) + outputs = self.deberta( + input_ids=flat_input_ids, + attention_mask=flat_attention_mask, + token_type_ids=flat_token_type_ids, + position_ids=flat_position_ids, + inputs_embeds=flat_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + pooled_output = self.pooler(sequence_output, training=training) + pooled_output = self.dropout(pooled_output, training=training) + logits = self.classifier(pooled_output) + reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices)) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=reshaped_logits) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFMultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 9b1aae4493..972ab49c0f 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -974,6 +974,13 @@ class TFDebertaV2ForMaskedLM(metaclass=DummyObject): requires_backends(self, ["tf"]) +class TFDebertaV2ForMultipleChoice(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + class TFDebertaV2ForQuestionAnswering(metaclass=DummyObject): _backends = ["tf"] diff --git a/tests/models/deberta_v2/test_modeling_tf_deberta_v2.py b/tests/models/deberta_v2/test_modeling_tf_deberta_v2.py index 8b9bcc15ea..b46f68525d 100644 --- a/tests/models/deberta_v2/test_modeling_tf_deberta_v2.py +++ b/tests/models/deberta_v2/test_modeling_tf_deberta_v2.py @@ -31,6 +31,7 @@ if is_tf_available(): from transformers import ( TFDebertaV2ForMaskedLM, + TFDebertaV2ForMultipleChoice, TFDebertaV2ForQuestionAnswering, TFDebertaV2ForSequenceClassification, TFDebertaV2ForTokenClassification, @@ -196,6 +197,22 @@ class TFDebertaV2ModelTester: self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length)) self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length)) + def create_and_check_for_multiple_choice( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_choices = self.num_choices + model = TFDebertaV2ForMultipleChoice(config=config) + multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1)) + multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1)) + multiple_choice_token_type_ids = tf.tile(tf.expand_dims(token_type_ids, 1), (1, self.num_choices, 1)) + inputs = { + "input_ids": multiple_choice_inputs_ids, + "attention_mask": multiple_choice_input_mask, + "token_type_ids": multiple_choice_token_type_ids, + } + result = model(inputs) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices)) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() ( @@ -218,6 +235,7 @@ class TFDebertaModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC TFDebertaV2Model, TFDebertaV2ForMaskedLM, TFDebertaV2ForQuestionAnswering, + TFDebertaV2ForMultipleChoice, TFDebertaV2ForSequenceClassification, TFDebertaV2ForTokenClassification, )