From 49b5ab6a27511de5168c72e83318164f1b4adc43 Mon Sep 17 00:00:00 2001 From: Abu Bakr Soliman Date: Wed, 26 Mar 2025 22:24:18 +0200 Subject: [PATCH] Support QuestionAnswering Module for ModernBert based models. (#35566) * push ModernBertForQuestionAnswering * update ModernBertForQuestionAnswering * update __init__ loading * set imports for ModernBertForQuestionAnswering * update ModernBertForQuestionAnswering * remove debugging logs * update init_weights method * remove custom initialization for ModernBertForQuestionAnswering * apply make fix-copies * apply make style * apply make fix-copies * append ModernBertForQuestionAnswering to the pipeline supported models * remove unused file * remove invalid autoload value * update en/model_doc/modernbert.md * apply make fixup command * make fixup * Update dummies * update usage tips for ModernBertForQuestionAnswering * update usage tips for ModernBertForQuestionAnswering * add init * add lint * add consistency * update init test * change text to trigger stuck text * use self.loss_function instead of custom loss By @Cyrilvallez Co-authored-by: Cyril Vallez * Update modeling_modernbert.py make comparable commit to even it out * Match whitespace * whitespace --------- Co-authored-by: Matt Co-authored-by: Orion Weller Co-authored-by: Orion Weller <31665361+orionw@users.noreply.github.com> Co-authored-by: Cyril Vallez --- docs/source/en/model_doc/modernbert.md | 13 +++ src/transformers/__init__.py | 2 + src/transformers/models/auto/modeling_auto.py | 1 + .../models/modernbert/modeling_modernbert.py | 101 +++++++++++++++++- .../models/modernbert/modular_modernbert.py | 94 +++++++++++++++- src/transformers/utils/dummy_pt_objects.py | 7 ++ .../modernbert/test_modeling_modernbert.py | 10 +- tests/test_modeling_common.py | 1 + 8 files changed, 225 insertions(+), 4 deletions(-) diff --git a/docs/source/en/model_doc/modernbert.md b/docs/source/en/model_doc/modernbert.md index f7ceaae187..a54adb04c2 100644 --- a/docs/source/en/model_doc/modernbert.md +++ b/docs/source/en/model_doc/modernbert.md @@ -60,6 +60,9 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h - [Masked language modeling task guide](../tasks/masked_language_modeling) + + +- [`ModernBertForQuestionAnswering`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/question-answering) and [colab notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/question_answering.ipynb). ## ModernBertConfig @@ -88,5 +91,15 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h [[autodoc]] ModernBertForTokenClassification - forward +## ModernBertForQuestionAnswering + +[[autodoc]] ModernBertForQuestionAnswering + - forward + +### Usage tips + +The ModernBert model can be fine-tuned using the HuggingFace Transformers library with its [official script](https://github.com/huggingface/transformers/blob/main/examples/pytorch/question-answering/run_qa.py) for question-answering tasks. + + diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 2db51e4770..03ed405628 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3047,6 +3047,7 @@ else: _import_structure["models.modernbert"].extend( [ "ModernBertForMaskedLM", + "ModernBertForQuestionAnswering", "ModernBertForSequenceClassification", "ModernBertForTokenClassification", "ModernBertModel", @@ -7967,6 +7968,7 @@ if TYPE_CHECKING: ) from .models.modernbert import ( ModernBertForMaskedLM, + ModernBertForQuestionAnswering, ModernBertForSequenceClassification, ModernBertForTokenClassification, ModernBertModel, diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 05a4157414..add0619aa8 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -1138,6 +1138,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( ("mistral", "MistralForQuestionAnswering"), ("mixtral", "MixtralForQuestionAnswering"), ("mobilebert", "MobileBertForQuestionAnswering"), + ("modernbert", "ModernBertForQuestionAnswering"), ("mpnet", "MPNetForQuestionAnswering"), ("mpt", "MptForQuestionAnswering"), ("mra", "MraForQuestionAnswering"), diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 3d39041253..28ba79b830 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -30,7 +30,13 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask -from ...modeling_outputs import BaseModelOutput, MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput +from ...modeling_outputs import ( + BaseModelOutput, + MaskedLMOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -650,7 +656,10 @@ class ModernBertPreTrainedModel(PreTrainedModel): init_weight(module.dense, stds["out"]) elif isinstance(module, ModernBertForMaskedLM): init_weight(module.decoder, stds["out"]) - elif isinstance(module, (ModernBertForSequenceClassification, ModernBertForTokenClassification)): + elif isinstance( + module, + (ModernBertForSequenceClassification, ModernBertForTokenClassification, ModernBertForQuestionAnswering), + ): init_weight(module.classifier, stds["final_out"]) @classmethod @@ -1384,10 +1393,98 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel): ) +@add_start_docstrings( + """ + The ModernBert Model with a span classification head on top for extractive question-answering tasks like SQuAD + (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + MODERNBERT_START_DOCSTRING, +) +class ModernBertForQuestionAnswering(ModernBertPreTrainedModel): + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.num_labels = config.num_labels + + self.model = ModernBertModel(config) + self.head = ModernBertPredictionHead(config) + self.drop = torch.nn.Dropout(config.classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.post_init() + + @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + self._maybe_set_compile() + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + + last_hidden_state = self.head(last_hidden_state) + last_hidden_state = self.drop(last_hidden_state) + logits = self.classifier(last_hidden_state) + + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) + + if not return_dict: + output = (start_logits, end_logits) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + __all__ = [ "ModernBertModel", "ModernBertPreTrainedModel", "ModernBertForMaskedLM", "ModernBertForSequenceClassification", "ModernBertForTokenClassification", + "ModernBertForQuestionAnswering", ] diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index edfdc94346..0901662f66 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -29,6 +29,7 @@ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import ( BaseModelOutput, MaskedLMOutput, + QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput, ) @@ -825,7 +826,10 @@ class ModernBertPreTrainedModel(PreTrainedModel): init_weight(module.dense, stds["out"]) elif isinstance(module, ModernBertForMaskedLM): init_weight(module.decoder, stds["out"]) - elif isinstance(module, (ModernBertForSequenceClassification, ModernBertForTokenClassification)): + elif isinstance( + module, + (ModernBertForSequenceClassification, ModernBertForTokenClassification, ModernBertForQuestionAnswering), + ): init_weight(module.classifier, stds["final_out"]) @classmethod @@ -1487,6 +1491,93 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel): ) +@add_start_docstrings( + """ + The ModernBert Model with a span classification head on top for extractive question-answering tasks like SQuAD + (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + MODERNBERT_START_DOCSTRING, +) +class ModernBertForQuestionAnswering(ModernBertPreTrainedModel): + def __init__(self, config: ModernBertConfig): + super().__init__(config) + self.num_labels = config.num_labels + + self.model = ModernBertModel(config) + self.head = ModernBertPredictionHead(config) + self.drop = torch.nn.Dropout(config.classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.post_init() + + @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + sliding_window_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + batch_size: Optional[int] = None, + seq_len: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + self._maybe_set_compile() + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=batch_size, + seq_len=seq_len, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + + last_hidden_state = self.head(last_hidden_state) + last_hidden_state = self.drop(last_hidden_state) + logits = self.classifier(last_hidden_state) + + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) + + if not return_dict: + output = (start_logits, end_logits) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + __all__ = [ "ModernBertConfig", "ModernBertModel", @@ -1494,4 +1585,5 @@ __all__ = [ "ModernBertForMaskedLM", "ModernBertForSequenceClassification", "ModernBertForTokenClassification", + "ModernBertForQuestionAnswering", ] diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 744a316161..b52d2bed78 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -6692,6 +6692,13 @@ class ModernBertForMaskedLM(metaclass=DummyObject): requires_backends(self, ["torch"]) +class ModernBertForQuestionAnswering(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class ModernBertForSequenceClassification(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/modernbert/test_modeling_modernbert.py b/tests/models/modernbert/test_modeling_modernbert.py index 238999da91..14882b0879 100644 --- a/tests/models/modernbert/test_modeling_modernbert.py +++ b/tests/models/modernbert/test_modeling_modernbert.py @@ -40,6 +40,7 @@ if is_torch_available(): from transformers import ( MODEL_FOR_PRETRAINING_MAPPING, ModernBertForMaskedLM, + ModernBertForQuestionAnswering, ModernBertForSequenceClassification, ModernBertForTokenClassification, ModernBertModel, @@ -224,6 +225,7 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa ModernBertForMaskedLM, ModernBertForSequenceClassification, ModernBertForTokenClassification, + ModernBertForQuestionAnswering, ) if is_torch_available() else () @@ -235,6 +237,7 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa "text-classification": ModernBertForSequenceClassification, "token-classification": ModernBertForTokenClassification, "zero-shot": ModernBertForSequenceClassification, + "question-answering": ModernBertForQuestionAnswering, } if is_torch_available() else {} @@ -289,7 +292,12 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa # are initialized without `initializer_range`, so they're not set to ~0 via the _config_zero_init if param.requires_grad and not ( name == "classifier.weight" - and model_class in [ModernBertForSequenceClassification, ModernBertForTokenClassification] + and model_class + in [ + ModernBertForSequenceClassification, + ModernBertForTokenClassification, + ModernBertForQuestionAnswering, + ] ): self.assertIn( ((param.data.mean() * 1e9).round() / 1e9).item(), diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 5454140a68..76bbf2766f 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3086,6 +3086,7 @@ class ModelTesterMixin: "ModernBertForSequenceClassification", "ModernBertForTokenClassification", "TimmWrapperForImageClassification", + "ModernBertForQuestionAnswering", ] special_param_names = [ r"^bit\.",