Support QuestionAnswering Module for ModernBert based models. (#35566)
* push ModernBertForQuestionAnswering * update ModernBertForQuestionAnswering * update __init__ loading * set imports for ModernBertForQuestionAnswering * update ModernBertForQuestionAnswering * remove debugging logs * update init_weights method * remove custom initialization for ModernBertForQuestionAnswering * apply make fix-copies * apply make style * apply make fix-copies * append ModernBertForQuestionAnswering to the pipeline supported models * remove unused file * remove invalid autoload value * update en/model_doc/modernbert.md * apply make fixup command * make fixup * Update dummies * update usage tips for ModernBertForQuestionAnswering * update usage tips for ModernBertForQuestionAnswering * add init * add lint * add consistency * update init test * change text to trigger stuck text * use self.loss_function instead of custom loss By @Cyrilvallez Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com> * Update modeling_modernbert.py make comparable commit to even it out * Match whitespace * whitespace --------- Co-authored-by: Matt <rocketknight1@gmail.com> Co-authored-by: Orion Weller <wellerorion@gmail.com> Co-authored-by: Orion Weller <31665361+orionw@users.noreply.github.com> Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
This commit is contained in:
@@ -60,6 +60,9 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
|
|||||||
|
|
||||||
- [Masked language modeling task guide](../tasks/masked_language_modeling)
|
- [Masked language modeling task guide](../tasks/masked_language_modeling)
|
||||||
|
|
||||||
|
<PipelineTag pipeline="question-answering"/>
|
||||||
|
|
||||||
|
- [`ModernBertForQuestionAnswering`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/question-answering) and [colab notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/question_answering.ipynb).
|
||||||
|
|
||||||
## ModernBertConfig
|
## ModernBertConfig
|
||||||
|
|
||||||
@@ -88,5 +91,15 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
|
|||||||
[[autodoc]] ModernBertForTokenClassification
|
[[autodoc]] ModernBertForTokenClassification
|
||||||
- forward
|
- forward
|
||||||
|
|
||||||
|
## ModernBertForQuestionAnswering
|
||||||
|
|
||||||
|
[[autodoc]] ModernBertForQuestionAnswering
|
||||||
|
- forward
|
||||||
|
|
||||||
|
### Usage tips
|
||||||
|
|
||||||
|
The ModernBert model can be fine-tuned using the HuggingFace Transformers library with its [official script](https://github.com/huggingface/transformers/blob/main/examples/pytorch/question-answering/run_qa.py) for question-answering tasks.
|
||||||
|
|
||||||
|
|
||||||
</pt>
|
</pt>
|
||||||
</frameworkcontent>
|
</frameworkcontent>
|
||||||
|
|||||||
@@ -3047,6 +3047,7 @@ else:
|
|||||||
_import_structure["models.modernbert"].extend(
|
_import_structure["models.modernbert"].extend(
|
||||||
[
|
[
|
||||||
"ModernBertForMaskedLM",
|
"ModernBertForMaskedLM",
|
||||||
|
"ModernBertForQuestionAnswering",
|
||||||
"ModernBertForSequenceClassification",
|
"ModernBertForSequenceClassification",
|
||||||
"ModernBertForTokenClassification",
|
"ModernBertForTokenClassification",
|
||||||
"ModernBertModel",
|
"ModernBertModel",
|
||||||
@@ -7967,6 +7968,7 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
from .models.modernbert import (
|
from .models.modernbert import (
|
||||||
ModernBertForMaskedLM,
|
ModernBertForMaskedLM,
|
||||||
|
ModernBertForQuestionAnswering,
|
||||||
ModernBertForSequenceClassification,
|
ModernBertForSequenceClassification,
|
||||||
ModernBertForTokenClassification,
|
ModernBertForTokenClassification,
|
||||||
ModernBertModel,
|
ModernBertModel,
|
||||||
|
|||||||
@@ -1138,6 +1138,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
|
|||||||
("mistral", "MistralForQuestionAnswering"),
|
("mistral", "MistralForQuestionAnswering"),
|
||||||
("mixtral", "MixtralForQuestionAnswering"),
|
("mixtral", "MixtralForQuestionAnswering"),
|
||||||
("mobilebert", "MobileBertForQuestionAnswering"),
|
("mobilebert", "MobileBertForQuestionAnswering"),
|
||||||
|
("modernbert", "ModernBertForQuestionAnswering"),
|
||||||
("mpnet", "MPNetForQuestionAnswering"),
|
("mpnet", "MPNetForQuestionAnswering"),
|
||||||
("mpt", "MptForQuestionAnswering"),
|
("mpt", "MptForQuestionAnswering"),
|
||||||
("mra", "MraForQuestionAnswering"),
|
("mra", "MraForQuestionAnswering"),
|
||||||
|
|||||||
@@ -30,7 +30,13 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||||
from ...modeling_outputs import BaseModelOutput, MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput
|
from ...modeling_outputs import (
|
||||||
|
BaseModelOutput,
|
||||||
|
MaskedLMOutput,
|
||||||
|
QuestionAnsweringModelOutput,
|
||||||
|
SequenceClassifierOutput,
|
||||||
|
TokenClassifierOutput,
|
||||||
|
)
|
||||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
@@ -650,7 +656,10 @@ class ModernBertPreTrainedModel(PreTrainedModel):
|
|||||||
init_weight(module.dense, stds["out"])
|
init_weight(module.dense, stds["out"])
|
||||||
elif isinstance(module, ModernBertForMaskedLM):
|
elif isinstance(module, ModernBertForMaskedLM):
|
||||||
init_weight(module.decoder, stds["out"])
|
init_weight(module.decoder, stds["out"])
|
||||||
elif isinstance(module, (ModernBertForSequenceClassification, ModernBertForTokenClassification)):
|
elif isinstance(
|
||||||
|
module,
|
||||||
|
(ModernBertForSequenceClassification, ModernBertForTokenClassification, ModernBertForQuestionAnswering),
|
||||||
|
):
|
||||||
init_weight(module.classifier, stds["final_out"])
|
init_weight(module.classifier, stds["final_out"])
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -1384,10 +1393,98 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"""
|
||||||
|
The ModernBert Model with a span classification head on top for extractive question-answering tasks like SQuAD
|
||||||
|
(a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
||||||
|
""",
|
||||||
|
MODERNBERT_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
|
||||||
|
def __init__(self, config: ModernBertConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self.num_labels = config.num_labels
|
||||||
|
|
||||||
|
self.model = ModernBertModel(config)
|
||||||
|
self.head = ModernBertPredictionHead(config)
|
||||||
|
self.drop = torch.nn.Dropout(config.classifier_dropout)
|
||||||
|
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
||||||
|
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING)
|
||||||
|
@add_code_sample_docstrings(
|
||||||
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
|
output_type=QuestionAnsweringModelOutput,
|
||||||
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.Tensor],
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
sliding_window_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.Tensor] = None,
|
||||||
|
start_positions: Optional[torch.Tensor] = None,
|
||||||
|
end_positions: Optional[torch.Tensor] = None,
|
||||||
|
indices: Optional[torch.Tensor] = None,
|
||||||
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
|
max_seqlen: Optional[int] = None,
|
||||||
|
batch_size: Optional[int] = None,
|
||||||
|
seq_len: Optional[int] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
self._maybe_set_compile()
|
||||||
|
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
sliding_window_mask=sliding_window_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
indices=indices,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
max_seqlen=max_seqlen,
|
||||||
|
batch_size=batch_size,
|
||||||
|
seq_len=seq_len,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
last_hidden_state = outputs[0]
|
||||||
|
|
||||||
|
last_hidden_state = self.head(last_hidden_state)
|
||||||
|
last_hidden_state = self.drop(last_hidden_state)
|
||||||
|
logits = self.classifier(last_hidden_state)
|
||||||
|
|
||||||
|
start_logits, end_logits = logits.split(1, dim=-1)
|
||||||
|
start_logits = start_logits.squeeze(-1).contiguous()
|
||||||
|
end_logits = end_logits.squeeze(-1).contiguous()
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if start_positions is not None and end_positions is not None:
|
||||||
|
loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (start_logits, end_logits) + outputs[1:]
|
||||||
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
|
return QuestionAnsweringModelOutput(
|
||||||
|
loss=loss,
|
||||||
|
start_logits=start_logits,
|
||||||
|
end_logits=end_logits,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ModernBertModel",
|
"ModernBertModel",
|
||||||
"ModernBertPreTrainedModel",
|
"ModernBertPreTrainedModel",
|
||||||
"ModernBertForMaskedLM",
|
"ModernBertForMaskedLM",
|
||||||
"ModernBertForSequenceClassification",
|
"ModernBertForSequenceClassification",
|
||||||
"ModernBertForTokenClassification",
|
"ModernBertForTokenClassification",
|
||||||
|
"ModernBertForQuestionAnswering",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
|||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutput,
|
BaseModelOutput,
|
||||||
MaskedLMOutput,
|
MaskedLMOutput,
|
||||||
|
QuestionAnsweringModelOutput,
|
||||||
SequenceClassifierOutput,
|
SequenceClassifierOutput,
|
||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
@@ -825,7 +826,10 @@ class ModernBertPreTrainedModel(PreTrainedModel):
|
|||||||
init_weight(module.dense, stds["out"])
|
init_weight(module.dense, stds["out"])
|
||||||
elif isinstance(module, ModernBertForMaskedLM):
|
elif isinstance(module, ModernBertForMaskedLM):
|
||||||
init_weight(module.decoder, stds["out"])
|
init_weight(module.decoder, stds["out"])
|
||||||
elif isinstance(module, (ModernBertForSequenceClassification, ModernBertForTokenClassification)):
|
elif isinstance(
|
||||||
|
module,
|
||||||
|
(ModernBertForSequenceClassification, ModernBertForTokenClassification, ModernBertForQuestionAnswering),
|
||||||
|
):
|
||||||
init_weight(module.classifier, stds["final_out"])
|
init_weight(module.classifier, stds["final_out"])
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -1487,6 +1491,93 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"""
|
||||||
|
The ModernBert Model with a span classification head on top for extractive question-answering tasks like SQuAD
|
||||||
|
(a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
||||||
|
""",
|
||||||
|
MODERNBERT_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
|
||||||
|
def __init__(self, config: ModernBertConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self.num_labels = config.num_labels
|
||||||
|
|
||||||
|
self.model = ModernBertModel(config)
|
||||||
|
self.head = ModernBertPredictionHead(config)
|
||||||
|
self.drop = torch.nn.Dropout(config.classifier_dropout)
|
||||||
|
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
||||||
|
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING)
|
||||||
|
@add_code_sample_docstrings(
|
||||||
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
|
output_type=QuestionAnsweringModelOutput,
|
||||||
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.Tensor],
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
sliding_window_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.Tensor] = None,
|
||||||
|
start_positions: Optional[torch.Tensor] = None,
|
||||||
|
end_positions: Optional[torch.Tensor] = None,
|
||||||
|
indices: Optional[torch.Tensor] = None,
|
||||||
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
|
max_seqlen: Optional[int] = None,
|
||||||
|
batch_size: Optional[int] = None,
|
||||||
|
seq_len: Optional[int] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
self._maybe_set_compile()
|
||||||
|
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
sliding_window_mask=sliding_window_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
indices=indices,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
max_seqlen=max_seqlen,
|
||||||
|
batch_size=batch_size,
|
||||||
|
seq_len=seq_len,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
last_hidden_state = outputs[0]
|
||||||
|
|
||||||
|
last_hidden_state = self.head(last_hidden_state)
|
||||||
|
last_hidden_state = self.drop(last_hidden_state)
|
||||||
|
logits = self.classifier(last_hidden_state)
|
||||||
|
|
||||||
|
start_logits, end_logits = logits.split(1, dim=-1)
|
||||||
|
start_logits = start_logits.squeeze(-1).contiguous()
|
||||||
|
end_logits = end_logits.squeeze(-1).contiguous()
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if start_positions is not None and end_positions is not None:
|
||||||
|
loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (start_logits, end_logits) + outputs[1:]
|
||||||
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
|
return QuestionAnsweringModelOutput(
|
||||||
|
loss=loss,
|
||||||
|
start_logits=start_logits,
|
||||||
|
end_logits=end_logits,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ModernBertConfig",
|
"ModernBertConfig",
|
||||||
"ModernBertModel",
|
"ModernBertModel",
|
||||||
@@ -1494,4 +1585,5 @@ __all__ = [
|
|||||||
"ModernBertForMaskedLM",
|
"ModernBertForMaskedLM",
|
||||||
"ModernBertForSequenceClassification",
|
"ModernBertForSequenceClassification",
|
||||||
"ModernBertForTokenClassification",
|
"ModernBertForTokenClassification",
|
||||||
|
"ModernBertForQuestionAnswering",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -6692,6 +6692,13 @@ class ModernBertForMaskedLM(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class ModernBertForQuestionAnswering(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class ModernBertForSequenceClassification(metaclass=DummyObject):
|
class ModernBertForSequenceClassification(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ if is_torch_available():
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
MODEL_FOR_PRETRAINING_MAPPING,
|
MODEL_FOR_PRETRAINING_MAPPING,
|
||||||
ModernBertForMaskedLM,
|
ModernBertForMaskedLM,
|
||||||
|
ModernBertForQuestionAnswering,
|
||||||
ModernBertForSequenceClassification,
|
ModernBertForSequenceClassification,
|
||||||
ModernBertForTokenClassification,
|
ModernBertForTokenClassification,
|
||||||
ModernBertModel,
|
ModernBertModel,
|
||||||
@@ -224,6 +225,7 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
|||||||
ModernBertForMaskedLM,
|
ModernBertForMaskedLM,
|
||||||
ModernBertForSequenceClassification,
|
ModernBertForSequenceClassification,
|
||||||
ModernBertForTokenClassification,
|
ModernBertForTokenClassification,
|
||||||
|
ModernBertForQuestionAnswering,
|
||||||
)
|
)
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else ()
|
else ()
|
||||||
@@ -235,6 +237,7 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
|||||||
"text-classification": ModernBertForSequenceClassification,
|
"text-classification": ModernBertForSequenceClassification,
|
||||||
"token-classification": ModernBertForTokenClassification,
|
"token-classification": ModernBertForTokenClassification,
|
||||||
"zero-shot": ModernBertForSequenceClassification,
|
"zero-shot": ModernBertForSequenceClassification,
|
||||||
|
"question-answering": ModernBertForQuestionAnswering,
|
||||||
}
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
@@ -289,7 +292,12 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
|||||||
# are initialized without `initializer_range`, so they're not set to ~0 via the _config_zero_init
|
# are initialized without `initializer_range`, so they're not set to ~0 via the _config_zero_init
|
||||||
if param.requires_grad and not (
|
if param.requires_grad and not (
|
||||||
name == "classifier.weight"
|
name == "classifier.weight"
|
||||||
and model_class in [ModernBertForSequenceClassification, ModernBertForTokenClassification]
|
and model_class
|
||||||
|
in [
|
||||||
|
ModernBertForSequenceClassification,
|
||||||
|
ModernBertForTokenClassification,
|
||||||
|
ModernBertForQuestionAnswering,
|
||||||
|
]
|
||||||
):
|
):
|
||||||
self.assertIn(
|
self.assertIn(
|
||||||
((param.data.mean() * 1e9).round() / 1e9).item(),
|
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||||
|
|||||||
@@ -3086,6 +3086,7 @@ class ModelTesterMixin:
|
|||||||
"ModernBertForSequenceClassification",
|
"ModernBertForSequenceClassification",
|
||||||
"ModernBertForTokenClassification",
|
"ModernBertForTokenClassification",
|
||||||
"TimmWrapperForImageClassification",
|
"TimmWrapperForImageClassification",
|
||||||
|
"ModernBertForQuestionAnswering",
|
||||||
]
|
]
|
||||||
special_param_names = [
|
special_param_names = [
|
||||||
r"^bit\.",
|
r"^bit\.",
|
||||||
|
|||||||
Reference in New Issue
Block a user