Add support for ModernBertForMultipleChoice (#39232)

* implement ModernBertForMultipleChoice

* fixup, style, repo consistency

* generate modeling_modernbert

* add tests + docs

* fix test
This commit is contained in:
Jan Netík
2025-08-04 20:45:43 +02:00
committed by GitHub
parent 801e869b67
commit 0bd91cc822
5 changed files with 279 additions and 2 deletions

View File

@@ -41,6 +41,7 @@ if is_torch_available():
from transformers import (
MODEL_FOR_PRETRAINING_MAPPING,
ModernBertForMaskedLM,
ModernBertForMultipleChoice,
ModernBertForQuestionAnswering,
ModernBertForSequenceClassification,
ModernBertForTokenClassification,
@@ -202,6 +203,22 @@ class ModernBertModelTester:
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def create_and_check_for_multiple_choice(
self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
model = ModernBertForMultipleChoice(config=config)
model.to(torch_device)
model.eval()
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
result = model(
multiple_choice_inputs_ids,
attention_mask=multiple_choice_input_mask,
labels=choice_labels,
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@@ -227,6 +244,7 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
ModernBertForSequenceClassification,
ModernBertForTokenClassification,
ModernBertForQuestionAnswering,
ModernBertForMultipleChoice,
)
if is_torch_available()
else ()
@@ -298,6 +316,7 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
ModernBertForSequenceClassification,
ModernBertForTokenClassification,
ModernBertForQuestionAnswering,
ModernBertForMultipleChoice,
]
):
self.assertIn(
@@ -318,6 +337,10 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
def test_for_warning_if_padding_and_no_attention_mask(self):
(
config,