diff --git a/docs/source/en/model_doc/modernbert.md b/docs/source/en/model_doc/modernbert.md index b2a57effb2..8c939adce0 100644 --- a/docs/source/en/model_doc/modernbert.md +++ b/docs/source/en/model_doc/modernbert.md @@ -115,6 +115,11 @@ echo -e "Plants create [MASK] through a process known as photosynthesis." | tran [[autodoc]] ModernBertForTokenClassification - forward +## ModernBertForMultipleChoice + +[[autodoc]] ModernBertForMultipleChoice + - forward + ## ModernBertForQuestionAnswering [[autodoc]] ModernBertForQuestionAnswering diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 2fbdb36b39..84672e80a7 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -1469,6 +1469,7 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict( ("mega", "MegaForMultipleChoice"), ("megatron-bert", "MegatronBertForMultipleChoice"), ("mobilebert", "MobileBertForMultipleChoice"), + ("modernbert", "ModernBertForMultipleChoice"), ("mpnet", "MPNetForMultipleChoice"), ("mra", "MraForMultipleChoice"), ("nezha", "NezhaForMultipleChoice"), diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 7b6476bd61..68bafabcbf 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -35,6 +35,7 @@ from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, MaskedLMOutput, + MultipleChoiceModelOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput, @@ -605,7 +606,12 @@ class ModernBertPreTrainedModel(PreTrainedModel): init_weight(module.decoder, stds["out"]) elif isinstance( module, - (ModernBertForSequenceClassification, ModernBertForTokenClassification, ModernBertForQuestionAnswering), + ( + ModernBertForSequenceClassification, + ModernBertForMultipleChoice, + ModernBertForTokenClassification, + ModernBertForQuestionAnswering, + ), ): init_weight(module.classifier, stds["final_out"]) elif isinstance(module, nn.LayerNorm): @@ -1393,6 +1399,123 @@ class ModernBertForQuestionAnswering(ModernBertPreTrainedModel): ) +@auto_docstring( + custom_intro=""" + The ModernBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a softmax) e.g. for RocStories/SWAG tasks. + """ +) +class ModernBertForMultipleChoice(ModernBertPreTrainedModel): + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.config = config + + self.model = ModernBertModel(config) + self.head = ModernBertPredictionHead(config) + self.drop = torch.nn.Dropout(config.classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]: + r""" + sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers + perform global attention, while the rest perform local attention. This mask is used to avoid attending to + far-away tokens in the local attention layers when not using Flash Attention. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. + indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*): + Indices of the non-padding tokens in the input sequence. Used for unpadding the output. + cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*): + Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors. + max_seqlen (`int`, *optional*): + Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors. + batch_size (`int`, *optional*): + Batch size of the input sequences. Used to pad the output tensors. + seq_len (`int`, *optional*): + Sequence length of the input sequences including padding tokens. Used to pad the output tensors. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + self._maybe_set_compile() + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + + if self.config.classifier_pooling == "cls": + last_hidden_state = last_hidden_state[:, 0] + elif self.config.classifier_pooling == "mean": + last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum( + dim=1, keepdim=True + ) + + pooled_output = self.head(last_hidden_state) + pooled_output = self.drop(pooled_output) + logits = self.classifier(pooled_output) + + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + __all__ = [ "ModernBertModel", "ModernBertPreTrainedModel", @@ -1400,4 +1523,5 @@ __all__ = [ "ModernBertForSequenceClassification", "ModernBertForTokenClassification", "ModernBertForQuestionAnswering", + "ModernBertForMultipleChoice", ] diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 3648e30ac9..7e1d5fbfb1 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -31,6 +31,7 @@ from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, MaskedLMOutput, + MultipleChoiceModelOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput, @@ -805,7 +806,12 @@ class ModernBertPreTrainedModel(PreTrainedModel): init_weight(module.decoder, stds["out"]) elif isinstance( module, - (ModernBertForSequenceClassification, ModernBertForTokenClassification, ModernBertForQuestionAnswering), + ( + ModernBertForSequenceClassification, + ModernBertForMultipleChoice, + ModernBertForTokenClassification, + ModernBertForQuestionAnswering, + ), ): init_weight(module.classifier, stds["final_out"]) elif isinstance(module, nn.LayerNorm): @@ -1521,6 +1527,123 @@ class ModernBertForQuestionAnswering(ModernBertPreTrainedModel): ) +@auto_docstring( + custom_intro=""" + The ModernBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a softmax) e.g. for RocStories/SWAG tasks. + """ +) +class ModernBertForMultipleChoice(ModernBertPreTrainedModel): + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.config = config + + self.model = ModernBertModel(config) + self.head = ModernBertPredictionHead(config) + self.drop = torch.nn.Dropout(config.classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]: + r""" + sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers + perform global attention, while the rest perform local attention. This mask is used to avoid attending to + far-away tokens in the local attention layers when not using Flash Attention. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. + indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*): + Indices of the non-padding tokens in the input sequence. Used for unpadding the output. + cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*): + Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors. + max_seqlen (`int`, *optional*): + Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors. + batch_size (`int`, *optional*): + Batch size of the input sequences. Used to pad the output tensors. + seq_len (`int`, *optional*): + Sequence length of the input sequences including padding tokens. Used to pad the output tensors. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + self._maybe_set_compile() + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + + if self.config.classifier_pooling == "cls": + last_hidden_state = last_hidden_state[:, 0] + elif self.config.classifier_pooling == "mean": + last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum( + dim=1, keepdim=True + ) + + pooled_output = self.head(last_hidden_state) + pooled_output = self.drop(pooled_output) + logits = self.classifier(pooled_output) + + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + __all__ = [ "ModernBertConfig", "ModernBertModel", @@ -1529,4 +1652,5 @@ __all__ = [ "ModernBertForSequenceClassification", "ModernBertForTokenClassification", "ModernBertForQuestionAnswering", + "ModernBertForMultipleChoice", ] diff --git a/tests/models/modernbert/test_modeling_modernbert.py b/tests/models/modernbert/test_modeling_modernbert.py index a6bb02e8b8..94445d46f9 100644 --- a/tests/models/modernbert/test_modeling_modernbert.py +++ b/tests/models/modernbert/test_modeling_modernbert.py @@ -41,6 +41,7 @@ if is_torch_available(): from transformers import ( MODEL_FOR_PRETRAINING_MAPPING, ModernBertForMaskedLM, + ModernBertForMultipleChoice, ModernBertForQuestionAnswering, ModernBertForSequenceClassification, ModernBertForTokenClassification, @@ -202,6 +203,22 @@ class ModernBertModelTester: result = model(input_ids, attention_mask=input_mask, labels=token_labels) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels)) + def create_and_check_for_multiple_choice( + self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_labels = self.num_labels + model = ModernBertForMultipleChoice(config=config) + model.to(torch_device) + model.eval() + multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + result = model( + multiple_choice_inputs_ids, + attention_mask=multiple_choice_input_mask, + labels=choice_labels, + ) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices)) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() ( @@ -227,6 +244,7 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa ModernBertForSequenceClassification, ModernBertForTokenClassification, ModernBertForQuestionAnswering, + ModernBertForMultipleChoice, ) if is_torch_available() else () @@ -298,6 +316,7 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa ModernBertForSequenceClassification, ModernBertForTokenClassification, ModernBertForQuestionAnswering, + ModernBertForMultipleChoice, ] ): self.assertIn( @@ -318,6 +337,10 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_token_classification(*config_and_inputs) + def test_for_multiple_choice(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs) + def test_for_warning_if_padding_and_no_attention_mask(self): ( config,