Add support for ModernBertForMultipleChoice (#39232)
* implement ModernBertForMultipleChoice * fixup, style, repo consistency * generate modeling_modernbert * add tests + docs * fix test
This commit is contained in:
@@ -41,6 +41,7 @@ if is_torch_available():
|
||||
from transformers import (
|
||||
MODEL_FOR_PRETRAINING_MAPPING,
|
||||
ModernBertForMaskedLM,
|
||||
ModernBertForMultipleChoice,
|
||||
ModernBertForQuestionAnswering,
|
||||
ModernBertForSequenceClassification,
|
||||
ModernBertForTokenClassification,
|
||||
@@ -202,6 +203,22 @@ class ModernBertModelTester:
|
||||
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
|
||||
|
||||
def create_and_check_for_multiple_choice(
|
||||
self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = ModernBertForMultipleChoice(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
result = model(
|
||||
multiple_choice_inputs_ids,
|
||||
attention_mask=multiple_choice_input_mask,
|
||||
labels=choice_labels,
|
||||
)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
@@ -227,6 +244,7 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
||||
ModernBertForSequenceClassification,
|
||||
ModernBertForTokenClassification,
|
||||
ModernBertForQuestionAnswering,
|
||||
ModernBertForMultipleChoice,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
@@ -298,6 +316,7 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
||||
ModernBertForSequenceClassification,
|
||||
ModernBertForTokenClassification,
|
||||
ModernBertForQuestionAnswering,
|
||||
ModernBertForMultipleChoice,
|
||||
]
|
||||
):
|
||||
self.assertIn(
|
||||
@@ -318,6 +337,10 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
|
||||
|
||||
def test_for_multiple_choice(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
|
||||
|
||||
def test_for_warning_if_padding_and_no_attention_mask(self):
|
||||
(
|
||||
config,
|
||||
|
||||
Reference in New Issue
Block a user