Add support for ModernBertForMultipleChoice (#39232)
* implement ModernBertForMultipleChoice * fixup, style, repo consistency * generate modeling_modernbert * add tests + docs * fix test
This commit is contained in:
@@ -115,6 +115,11 @@ echo -e "Plants create [MASK] through a process known as photosynthesis." | tran
|
||||
[[autodoc]] ModernBertForTokenClassification
|
||||
- forward
|
||||
|
||||
## ModernBertForMultipleChoice
|
||||
|
||||
[[autodoc]] ModernBertForMultipleChoice
|
||||
- forward
|
||||
|
||||
## ModernBertForQuestionAnswering
|
||||
|
||||
[[autodoc]] ModernBertForQuestionAnswering
|
||||
|
||||
@@ -1469,6 +1469,7 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
|
||||
("mega", "MegaForMultipleChoice"),
|
||||
("megatron-bert", "MegatronBertForMultipleChoice"),
|
||||
("mobilebert", "MobileBertForMultipleChoice"),
|
||||
("modernbert", "ModernBertForMultipleChoice"),
|
||||
("mpnet", "MPNetForMultipleChoice"),
|
||||
("mra", "MraForMultipleChoice"),
|
||||
("nezha", "NezhaForMultipleChoice"),
|
||||
|
||||
@@ -35,6 +35,7 @@ from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
MaskedLMOutput,
|
||||
MultipleChoiceModelOutput,
|
||||
QuestionAnsweringModelOutput,
|
||||
SequenceClassifierOutput,
|
||||
TokenClassifierOutput,
|
||||
@@ -605,7 +606,12 @@ class ModernBertPreTrainedModel(PreTrainedModel):
|
||||
init_weight(module.decoder, stds["out"])
|
||||
elif isinstance(
|
||||
module,
|
||||
(ModernBertForSequenceClassification, ModernBertForTokenClassification, ModernBertForQuestionAnswering),
|
||||
(
|
||||
ModernBertForSequenceClassification,
|
||||
ModernBertForMultipleChoice,
|
||||
ModernBertForTokenClassification,
|
||||
ModernBertForQuestionAnswering,
|
||||
),
|
||||
):
|
||||
init_weight(module.classifier, stds["final_out"])
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
@@ -1393,6 +1399,123 @@ class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
|
||||
)
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The ModernBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a softmax) e.g. for RocStories/SWAG tasks.
|
||||
"""
|
||||
)
|
||||
class ModernBertForMultipleChoice(ModernBertPreTrainedModel):
|
||||
def __init__(self, config: ModernBertConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
self.model = ModernBertModel(config)
|
||||
self.head = ModernBertPredictionHead(config)
|
||||
self.drop = torch.nn.Dropout(config.classifier_dropout)
|
||||
self.classifier = nn.Linear(config.hidden_size, 1)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
sliding_window_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
indices: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[int] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
seq_len: Optional[int] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]:
|
||||
r"""
|
||||
sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
|
||||
perform global attention, while the rest perform local attention. This mask is used to avoid attending to
|
||||
far-away tokens in the local attention layers when not using Flash Attention.
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
|
||||
num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors.
|
||||
indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
|
||||
Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
|
||||
cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
|
||||
Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
|
||||
max_seqlen (`int`, *optional*):
|
||||
Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
|
||||
batch_size (`int`, *optional*):
|
||||
Batch size of the input sequences. Used to pad the output tensors.
|
||||
seq_len (`int`, *optional*):
|
||||
Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
|
||||
|
||||
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
|
||||
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
||||
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
||||
inputs_embeds = (
|
||||
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
|
||||
if inputs_embeds is not None
|
||||
else None
|
||||
)
|
||||
|
||||
self._maybe_set_compile()
|
||||
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
sliding_window_mask=sliding_window_mask,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
indices=indices,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
batch_size=batch_size,
|
||||
seq_len=seq_len,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
last_hidden_state = outputs[0]
|
||||
|
||||
if self.config.classifier_pooling == "cls":
|
||||
last_hidden_state = last_hidden_state[:, 0]
|
||||
elif self.config.classifier_pooling == "mean":
|
||||
last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
|
||||
dim=1, keepdim=True
|
||||
)
|
||||
|
||||
pooled_output = self.head(last_hidden_state)
|
||||
pooled_output = self.drop(pooled_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
|
||||
reshaped_logits = logits.view(-1, num_choices)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = nn.CrossEntropyLoss()
|
||||
loss = loss_fct(reshaped_logits, labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (reshaped_logits,) + outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return MultipleChoiceModelOutput(
|
||||
loss=loss,
|
||||
logits=reshaped_logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ModernBertModel",
|
||||
"ModernBertPreTrainedModel",
|
||||
@@ -1400,4 +1523,5 @@ __all__ = [
|
||||
"ModernBertForSequenceClassification",
|
||||
"ModernBertForTokenClassification",
|
||||
"ModernBertForQuestionAnswering",
|
||||
"ModernBertForMultipleChoice",
|
||||
]
|
||||
|
||||
@@ -31,6 +31,7 @@ from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
MaskedLMOutput,
|
||||
MultipleChoiceModelOutput,
|
||||
QuestionAnsweringModelOutput,
|
||||
SequenceClassifierOutput,
|
||||
TokenClassifierOutput,
|
||||
@@ -805,7 +806,12 @@ class ModernBertPreTrainedModel(PreTrainedModel):
|
||||
init_weight(module.decoder, stds["out"])
|
||||
elif isinstance(
|
||||
module,
|
||||
(ModernBertForSequenceClassification, ModernBertForTokenClassification, ModernBertForQuestionAnswering),
|
||||
(
|
||||
ModernBertForSequenceClassification,
|
||||
ModernBertForMultipleChoice,
|
||||
ModernBertForTokenClassification,
|
||||
ModernBertForQuestionAnswering,
|
||||
),
|
||||
):
|
||||
init_weight(module.classifier, stds["final_out"])
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
@@ -1521,6 +1527,123 @@ class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
|
||||
)
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The ModernBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a softmax) e.g. for RocStories/SWAG tasks.
|
||||
"""
|
||||
)
|
||||
class ModernBertForMultipleChoice(ModernBertPreTrainedModel):
|
||||
def __init__(self, config: ModernBertConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
self.model = ModernBertModel(config)
|
||||
self.head = ModernBertPredictionHead(config)
|
||||
self.drop = torch.nn.Dropout(config.classifier_dropout)
|
||||
self.classifier = nn.Linear(config.hidden_size, 1)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
sliding_window_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
indices: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[int] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
seq_len: Optional[int] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]:
|
||||
r"""
|
||||
sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
|
||||
perform global attention, while the rest perform local attention. This mask is used to avoid attending to
|
||||
far-away tokens in the local attention layers when not using Flash Attention.
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
|
||||
num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors.
|
||||
indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
|
||||
Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
|
||||
cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
|
||||
Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
|
||||
max_seqlen (`int`, *optional*):
|
||||
Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
|
||||
batch_size (`int`, *optional*):
|
||||
Batch size of the input sequences. Used to pad the output tensors.
|
||||
seq_len (`int`, *optional*):
|
||||
Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
|
||||
|
||||
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
|
||||
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
||||
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
||||
inputs_embeds = (
|
||||
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
|
||||
if inputs_embeds is not None
|
||||
else None
|
||||
)
|
||||
|
||||
self._maybe_set_compile()
|
||||
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
sliding_window_mask=sliding_window_mask,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
indices=indices,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
batch_size=batch_size,
|
||||
seq_len=seq_len,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
last_hidden_state = outputs[0]
|
||||
|
||||
if self.config.classifier_pooling == "cls":
|
||||
last_hidden_state = last_hidden_state[:, 0]
|
||||
elif self.config.classifier_pooling == "mean":
|
||||
last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
|
||||
dim=1, keepdim=True
|
||||
)
|
||||
|
||||
pooled_output = self.head(last_hidden_state)
|
||||
pooled_output = self.drop(pooled_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
|
||||
reshaped_logits = logits.view(-1, num_choices)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = nn.CrossEntropyLoss()
|
||||
loss = loss_fct(reshaped_logits, labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (reshaped_logits,) + outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return MultipleChoiceModelOutput(
|
||||
loss=loss,
|
||||
logits=reshaped_logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ModernBertConfig",
|
||||
"ModernBertModel",
|
||||
@@ -1529,4 +1652,5 @@ __all__ = [
|
||||
"ModernBertForSequenceClassification",
|
||||
"ModernBertForTokenClassification",
|
||||
"ModernBertForQuestionAnswering",
|
||||
"ModernBertForMultipleChoice",
|
||||
]
|
||||
|
||||
@@ -41,6 +41,7 @@ if is_torch_available():
|
||||
from transformers import (
|
||||
MODEL_FOR_PRETRAINING_MAPPING,
|
||||
ModernBertForMaskedLM,
|
||||
ModernBertForMultipleChoice,
|
||||
ModernBertForQuestionAnswering,
|
||||
ModernBertForSequenceClassification,
|
||||
ModernBertForTokenClassification,
|
||||
@@ -202,6 +203,22 @@ class ModernBertModelTester:
|
||||
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
|
||||
|
||||
def create_and_check_for_multiple_choice(
|
||||
self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = ModernBertForMultipleChoice(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
result = model(
|
||||
multiple_choice_inputs_ids,
|
||||
attention_mask=multiple_choice_input_mask,
|
||||
labels=choice_labels,
|
||||
)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
@@ -227,6 +244,7 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
||||
ModernBertForSequenceClassification,
|
||||
ModernBertForTokenClassification,
|
||||
ModernBertForQuestionAnswering,
|
||||
ModernBertForMultipleChoice,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
@@ -298,6 +316,7 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
||||
ModernBertForSequenceClassification,
|
||||
ModernBertForTokenClassification,
|
||||
ModernBertForQuestionAnswering,
|
||||
ModernBertForMultipleChoice,
|
||||
]
|
||||
):
|
||||
self.assertIn(
|
||||
@@ -318,6 +337,10 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
|
||||
|
||||
def test_for_multiple_choice(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
|
||||
|
||||
def test_for_warning_if_padding_and_no_attention_mask(self):
|
||||
(
|
||||
config,
|
||||
|
||||
Reference in New Issue
Block a user