Support QuestionAnswering Module for ModernBert based models. (#35566)
* push ModernBertForQuestionAnswering * update ModernBertForQuestionAnswering * update __init__ loading * set imports for ModernBertForQuestionAnswering * update ModernBertForQuestionAnswering * remove debugging logs * update init_weights method * remove custom initialization for ModernBertForQuestionAnswering * apply make fix-copies * apply make style * apply make fix-copies * append ModernBertForQuestionAnswering to the pipeline supported models * remove unused file * remove invalid autoload value * update en/model_doc/modernbert.md * apply make fixup command * make fixup * Update dummies * update usage tips for ModernBertForQuestionAnswering * update usage tips for ModernBertForQuestionAnswering * add init * add lint * add consistency * update init test * change text to trigger stuck text * use self.loss_function instead of custom loss By @Cyrilvallez Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com> * Update modeling_modernbert.py make comparable commit to even it out * Match whitespace * whitespace --------- Co-authored-by: Matt <rocketknight1@gmail.com> Co-authored-by: Orion Weller <wellerorion@gmail.com> Co-authored-by: Orion Weller <31665361+orionw@users.noreply.github.com> Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
This commit is contained in:
@@ -3047,6 +3047,7 @@ else:
|
||||
_import_structure["models.modernbert"].extend(
|
||||
[
|
||||
"ModernBertForMaskedLM",
|
||||
"ModernBertForQuestionAnswering",
|
||||
"ModernBertForSequenceClassification",
|
||||
"ModernBertForTokenClassification",
|
||||
"ModernBertModel",
|
||||
@@ -7967,6 +7968,7 @@ if TYPE_CHECKING:
|
||||
)
|
||||
from .models.modernbert import (
|
||||
ModernBertForMaskedLM,
|
||||
ModernBertForQuestionAnswering,
|
||||
ModernBertForSequenceClassification,
|
||||
ModernBertForTokenClassification,
|
||||
ModernBertModel,
|
||||
|
||||
@@ -1138,6 +1138,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
|
||||
("mistral", "MistralForQuestionAnswering"),
|
||||
("mixtral", "MixtralForQuestionAnswering"),
|
||||
("mobilebert", "MobileBertForQuestionAnswering"),
|
||||
("modernbert", "ModernBertForQuestionAnswering"),
|
||||
("mpnet", "MPNetForQuestionAnswering"),
|
||||
("mpt", "MptForQuestionAnswering"),
|
||||
("mra", "MraForQuestionAnswering"),
|
||||
|
||||
@@ -30,7 +30,13 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||
from ...modeling_outputs import BaseModelOutput, MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
MaskedLMOutput,
|
||||
QuestionAnsweringModelOutput,
|
||||
SequenceClassifierOutput,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
@@ -650,7 +656,10 @@ class ModernBertPreTrainedModel(PreTrainedModel):
|
||||
init_weight(module.dense, stds["out"])
|
||||
elif isinstance(module, ModernBertForMaskedLM):
|
||||
init_weight(module.decoder, stds["out"])
|
||||
elif isinstance(module, (ModernBertForSequenceClassification, ModernBertForTokenClassification)):
|
||||
elif isinstance(
|
||||
module,
|
||||
(ModernBertForSequenceClassification, ModernBertForTokenClassification, ModernBertForQuestionAnswering),
|
||||
):
|
||||
init_weight(module.classifier, stds["final_out"])
|
||||
|
||||
@classmethod
|
||||
@@ -1384,10 +1393,98 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
The ModernBert Model with a span classification head on top for extractive question-answering tasks like SQuAD
|
||||
(a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
||||
""",
|
||||
MODERNBERT_START_DOCSTRING,
|
||||
)
|
||||
class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
|
||||
def __init__(self, config: ModernBertConfig):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.model = ModernBertModel(config)
|
||||
self.head = ModernBertPredictionHead(config)
|
||||
self.drop = torch.nn.Dropout(config.classifier_dropout)
|
||||
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
self.post_init()
|
||||
|
||||
@add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=QuestionAnsweringModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
sliding_window_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
start_positions: Optional[torch.Tensor] = None,
|
||||
end_positions: Optional[torch.Tensor] = None,
|
||||
indices: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[int] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
seq_len: Optional[int] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
self._maybe_set_compile()
|
||||
|
||||
outputs = self.model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
sliding_window_mask=sliding_window_mask,
|
||||
position_ids=position_ids,
|
||||
indices=indices,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
batch_size=batch_size,
|
||||
seq_len=seq_len,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
last_hidden_state = outputs[0]
|
||||
|
||||
last_hidden_state = self.head(last_hidden_state)
|
||||
last_hidden_state = self.drop(last_hidden_state)
|
||||
logits = self.classifier(last_hidden_state)
|
||||
|
||||
start_logits, end_logits = logits.split(1, dim=-1)
|
||||
start_logits = start_logits.squeeze(-1).contiguous()
|
||||
end_logits = end_logits.squeeze(-1).contiguous()
|
||||
|
||||
loss = None
|
||||
if start_positions is not None and end_positions is not None:
|
||||
loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
|
||||
|
||||
if not return_dict:
|
||||
output = (start_logits, end_logits) + outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return QuestionAnsweringModelOutput(
|
||||
loss=loss,
|
||||
start_logits=start_logits,
|
||||
end_logits=end_logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ModernBertModel",
|
||||
"ModernBertPreTrainedModel",
|
||||
"ModernBertForMaskedLM",
|
||||
"ModernBertForSequenceClassification",
|
||||
"ModernBertForTokenClassification",
|
||||
"ModernBertForQuestionAnswering",
|
||||
]
|
||||
|
||||
@@ -29,6 +29,7 @@ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
MaskedLMOutput,
|
||||
QuestionAnsweringModelOutput,
|
||||
SequenceClassifierOutput,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
@@ -825,7 +826,10 @@ class ModernBertPreTrainedModel(PreTrainedModel):
|
||||
init_weight(module.dense, stds["out"])
|
||||
elif isinstance(module, ModernBertForMaskedLM):
|
||||
init_weight(module.decoder, stds["out"])
|
||||
elif isinstance(module, (ModernBertForSequenceClassification, ModernBertForTokenClassification)):
|
||||
elif isinstance(
|
||||
module,
|
||||
(ModernBertForSequenceClassification, ModernBertForTokenClassification, ModernBertForQuestionAnswering),
|
||||
):
|
||||
init_weight(module.classifier, stds["final_out"])
|
||||
|
||||
@classmethod
|
||||
@@ -1487,6 +1491,93 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
The ModernBert Model with a span classification head on top for extractive question-answering tasks like SQuAD
|
||||
(a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
||||
""",
|
||||
MODERNBERT_START_DOCSTRING,
|
||||
)
|
||||
class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
|
||||
def __init__(self, config: ModernBertConfig):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.model = ModernBertModel(config)
|
||||
self.head = ModernBertPredictionHead(config)
|
||||
self.drop = torch.nn.Dropout(config.classifier_dropout)
|
||||
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
self.post_init()
|
||||
|
||||
@add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=QuestionAnsweringModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
sliding_window_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
start_positions: Optional[torch.Tensor] = None,
|
||||
end_positions: Optional[torch.Tensor] = None,
|
||||
indices: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[int] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
seq_len: Optional[int] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
self._maybe_set_compile()
|
||||
|
||||
outputs = self.model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
sliding_window_mask=sliding_window_mask,
|
||||
position_ids=position_ids,
|
||||
indices=indices,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
batch_size=batch_size,
|
||||
seq_len=seq_len,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
last_hidden_state = outputs[0]
|
||||
|
||||
last_hidden_state = self.head(last_hidden_state)
|
||||
last_hidden_state = self.drop(last_hidden_state)
|
||||
logits = self.classifier(last_hidden_state)
|
||||
|
||||
start_logits, end_logits = logits.split(1, dim=-1)
|
||||
start_logits = start_logits.squeeze(-1).contiguous()
|
||||
end_logits = end_logits.squeeze(-1).contiguous()
|
||||
|
||||
loss = None
|
||||
if start_positions is not None and end_positions is not None:
|
||||
loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
|
||||
|
||||
if not return_dict:
|
||||
output = (start_logits, end_logits) + outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return QuestionAnsweringModelOutput(
|
||||
loss=loss,
|
||||
start_logits=start_logits,
|
||||
end_logits=end_logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ModernBertConfig",
|
||||
"ModernBertModel",
|
||||
@@ -1494,4 +1585,5 @@ __all__ = [
|
||||
"ModernBertForMaskedLM",
|
||||
"ModernBertForSequenceClassification",
|
||||
"ModernBertForTokenClassification",
|
||||
"ModernBertForQuestionAnswering",
|
||||
]
|
||||
|
||||
@@ -6692,6 +6692,13 @@ class ModernBertForMaskedLM(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class ModernBertForQuestionAnswering(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class ModernBertForSequenceClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user