From 7434c0ed21a154136b0145b0245ae9058005abac Mon Sep 17 00:00:00 2001 From: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com> Date: Mon, 14 Oct 2024 08:53:32 +0200 Subject: [PATCH] Mistral-related models for QnA (#34045) * mistral qna start * mixtral qna * oops * qwen2 qna * qwen2moe qna * add missing input embed methods * add copied to all methods, can't directly from llama due to the prefix * make top level copied from --- docs/source/en/model_doc/mistral.md | 5 + docs/source/en/model_doc/mixtral.md | 4 + docs/source/en/model_doc/qwen2.md | 5 + docs/source/en/model_doc/qwen2_moe.md | 5 + src/transformers/__init__.py | 8 ++ src/transformers/models/auto/modeling_auto.py | 4 + src/transformers/models/mistral/__init__.py | 2 + .../models/mistral/modeling_mistral.py | 101 ++++++++++++++++++ src/transformers/models/mixtral/__init__.py | 2 + .../models/mixtral/modeling_mixtral.py | 101 ++++++++++++++++++ src/transformers/models/qwen2/__init__.py | 2 + .../models/qwen2/modeling_qwen2.py | 101 ++++++++++++++++++ src/transformers/models/qwen2_moe/__init__.py | 2 + .../models/qwen2_moe/modeling_qwen2_moe.py | 101 ++++++++++++++++++ src/transformers/utils/dummy_pt_objects.py | 28 +++++ tests/models/mistral/test_modeling_mistral.py | 10 +- tests/models/mixtral/test_modeling_mixtral.py | 10 +- tests/models/qwen2/test_modeling_qwen2.py | 10 +- .../qwen2_moe/test_modeling_qwen2_moe.py | 10 +- 19 files changed, 507 insertions(+), 4 deletions(-) diff --git a/docs/source/en/model_doc/mistral.md b/docs/source/en/model_doc/mistral.md index 17ce15b2b8..2be657109a 100644 --- a/docs/source/en/model_doc/mistral.md +++ b/docs/source/en/model_doc/mistral.md @@ -208,6 +208,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h [[autodoc]] MistralForTokenClassification - forward +## MistralForQuestionAnswering + +[[autodoc]] MistralForQuestionAnswering +- forward + ## FlaxMistralModel [[autodoc]] FlaxMistralModel diff --git 
a/docs/source/en/model_doc/mixtral.md b/docs/source/en/model_doc/mixtral.md index 71c7d7921e..7afcaa798e 100644 --- a/docs/source/en/model_doc/mixtral.md +++ b/docs/source/en/model_doc/mixtral.md @@ -209,3 +209,7 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h [[autodoc]] MixtralForTokenClassification - forward + +## MixtralForQuestionAnswering +[[autodoc]] MixtralForQuestionAnswering + - forward diff --git a/docs/source/en/model_doc/qwen2.md b/docs/source/en/model_doc/qwen2.md index 16815f2fc1..78138413c7 100644 --- a/docs/source/en/model_doc/qwen2.md +++ b/docs/source/en/model_doc/qwen2.md @@ -85,3 +85,8 @@ In the following, we demonstrate how to use `Qwen2-7B-Instruct` for the inferenc [[autodoc]] Qwen2ForTokenClassification - forward + +## Qwen2ForQuestionAnswering + +[[autodoc]] Qwen2ForQuestionAnswering + - forward diff --git a/docs/source/en/model_doc/qwen2_moe.md b/docs/source/en/model_doc/qwen2_moe.md index 9c6dc80beb..3a7391ca19 100644 --- a/docs/source/en/model_doc/qwen2_moe.md +++ b/docs/source/en/model_doc/qwen2_moe.md @@ -80,3 +80,8 @@ In the following, we demonstrate how to use `Qwen1.5-MoE-A2.7B-Chat` for the inf [[autodoc]] Qwen2MoeForTokenClassification - forward + +## Qwen2MoeForQuestionAnswering + +[[autodoc]] Qwen2MoeForQuestionAnswering + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index ab829c6894..daffe11987 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -2709,6 +2709,7 @@ else: _import_structure["models.mistral"].extend( [ "MistralForCausalLM", + "MistralForQuestionAnswering", "MistralForSequenceClassification", "MistralForTokenClassification", "MistralModel", @@ -2718,6 +2719,7 @@ else: _import_structure["models.mixtral"].extend( [ "MixtralForCausalLM", + "MixtralForQuestionAnswering", "MixtralForSequenceClassification", "MixtralForTokenClassification", "MixtralModel", @@ -3094,6 +3096,7 @@ else: 
_import_structure["models.qwen2"].extend( [ "Qwen2ForCausalLM", + "Qwen2ForQuestionAnswering", "Qwen2ForSequenceClassification", "Qwen2ForTokenClassification", "Qwen2Model", @@ -3110,6 +3113,7 @@ else: _import_structure["models.qwen2_moe"].extend( [ "Qwen2MoeForCausalLM", + "Qwen2MoeForQuestionAnswering", "Qwen2MoeForSequenceClassification", "Qwen2MoeForTokenClassification", "Qwen2MoeModel", @@ -7323,6 +7327,7 @@ if TYPE_CHECKING: ) from .models.mistral import ( MistralForCausalLM, + MistralForQuestionAnswering, MistralForSequenceClassification, MistralForTokenClassification, MistralModel, @@ -7330,6 +7335,7 @@ if TYPE_CHECKING: ) from .models.mixtral import ( MixtralForCausalLM, + MixtralForQuestionAnswering, MixtralForSequenceClassification, MixtralForTokenClassification, MixtralModel, @@ -7625,6 +7631,7 @@ if TYPE_CHECKING: ) from .models.qwen2 import ( Qwen2ForCausalLM, + Qwen2ForQuestionAnswering, Qwen2ForSequenceClassification, Qwen2ForTokenClassification, Qwen2Model, @@ -7637,6 +7644,7 @@ if TYPE_CHECKING: ) from .models.qwen2_moe import ( Qwen2MoeForCausalLM, + Qwen2MoeForQuestionAnswering, Qwen2MoeForSequenceClassification, Qwen2MoeForTokenClassification, Qwen2MoeModel, diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index aa0d59de52..dbfcccaa46 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -1046,6 +1046,8 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( ("mbart", "MBartForQuestionAnswering"), ("mega", "MegaForQuestionAnswering"), ("megatron-bert", "MegatronBertForQuestionAnswering"), + ("mistral", "MistralForQuestionAnswering"), + ("mixtral", "MixtralForQuestionAnswering"), ("mobilebert", "MobileBertForQuestionAnswering"), ("mpnet", "MPNetForQuestionAnswering"), ("mpt", "MptForQuestionAnswering"), @@ -1057,6 +1059,8 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( ("nystromformer", 
"NystromformerForQuestionAnswering"), ("opt", "OPTForQuestionAnswering"), ("qdqbert", "QDQBertForQuestionAnswering"), + ("qwen2", "Qwen2ForQuestionAnswering"), + ("qwen2_moe", "Qwen2MoeForQuestionAnswering"), ("reformer", "ReformerForQuestionAnswering"), ("rembert", "RemBertForQuestionAnswering"), ("roberta", "RobertaForQuestionAnswering"), diff --git a/src/transformers/models/mistral/__init__.py b/src/transformers/models/mistral/__init__.py index 93e551e193..31441efe65 100644 --- a/src/transformers/models/mistral/__init__.py +++ b/src/transformers/models/mistral/__init__.py @@ -35,6 +35,7 @@ except OptionalDependencyNotAvailable: else: _import_structure["modeling_mistral"] = [ "MistralForCausalLM", + "MistralForQuestionAnswering", "MistralModel", "MistralPreTrainedModel", "MistralForSequenceClassification", @@ -78,6 +79,7 @@ if TYPE_CHECKING: else: from .modeling_mistral import ( MistralForCausalLM, + MistralForQuestionAnswering, MistralForSequenceClassification, MistralForTokenClassification, MistralModel, diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index b0ffe3e56e..1bb7a3f109 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -34,6 +34,7 @@ from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, + QuestionAnsweringModelOutput, SequenceClassifierOutputWithPast, TokenClassifierOutput, ) @@ -1408,3 +1409,103 @@ class MistralForTokenClassification(MistralPreTrainedModel): hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +@add_start_docstrings( + """ +The Mistral Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). 
+ """, + MISTRAL_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForQuestionAnswering with Llama->Mistral,LLAMA->MISTRAL,transformer->model +class MistralForQuestionAnswering(MistralPreTrainedModel): + base_model_prefix = "model" + + # Copied from models.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Mistral + def __init__(self, config): + super().__init__(config) + self.model = MistralModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, the split adds a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/mixtral/__init__.py
b/src/transformers/models/mixtral/__init__.py index b124d41dfb..4ee4834dd2 100644 --- a/src/transformers/models/mixtral/__init__.py +++ b/src/transformers/models/mixtral/__init__.py @@ -33,6 +33,7 @@ except OptionalDependencyNotAvailable: else: _import_structure["modeling_mixtral"] = [ "MixtralForCausalLM", + "MixtralForQuestionAnswering", "MixtralModel", "MixtralPreTrainedModel", "MixtralForSequenceClassification", @@ -51,6 +52,7 @@ if TYPE_CHECKING: else: from .modeling_mixtral import ( MixtralForCausalLM, + MixtralForQuestionAnswering, MixtralForSequenceClassification, MixtralForTokenClassification, MixtralModel, diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 9c7fadbb8f..9bb0654f03 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -35,6 +35,7 @@ from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_caus from ...modeling_outputs import ( MoeCausalLMOutputWithPast, MoeModelOutputWithPast, + QuestionAnsweringModelOutput, SequenceClassifierOutputWithPast, TokenClassifierOutput, ) @@ -1644,3 +1645,103 @@ class MixtralForTokenClassification(MixtralPreTrainedModel): hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +@add_start_docstrings( + """ +The Mixtral Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). 
+ """, + MIXTRAL_START_DOCSTRING, +) +# Copied from transformers.models.mistral.modeling_mistral.MistralForQuestionAnswering with Mistral->Mixtral, MISTRAL->MIXTRAL +class MixtralForQuestionAnswering(MixtralPreTrainedModel): + base_model_prefix = "model" + + # Copied from models.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Mixtral + def __init__(self, config): + super().__init__(config) + self.model = MixtralModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`).
Positions outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, the split adds a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/qwen2/__init__.py b/src/transformers/models/qwen2/__init__.py index 35df37e91a..301531655a 100644 ---
a/src/transformers/models/qwen2/__init__.py +++ b/src/transformers/models/qwen2/__init__.py @@ -42,6 +42,7 @@ except OptionalDependencyNotAvailable: else: _import_structure["modeling_qwen2"] = [ "Qwen2ForCausalLM", + "Qwen2ForQuestionAnswering", "Qwen2Model", "Qwen2PreTrainedModel", "Qwen2ForSequenceClassification", @@ -69,6 +70,7 @@ if TYPE_CHECKING: else: from .modeling_qwen2 import ( Qwen2ForCausalLM, + Qwen2ForQuestionAnswering, Qwen2ForSequenceClassification, Qwen2ForTokenClassification, Qwen2Model, diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 50f273ba76..2585352fc9 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -34,6 +34,7 @@ from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, + QuestionAnsweringModelOutput, SequenceClassifierOutputWithPast, TokenClassifierOutput, ) @@ -1508,3 +1509,103 @@ class Qwen2ForTokenClassification(Qwen2PreTrainedModel): hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +@add_start_docstrings( + """ +The Qwen2 Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). 
+ """, + QWEN2_START_DOCSTRING, +) +# Copied from transformers.models.mistral.modeling_mistral.MistralForQuestionAnswering with Mistral->Qwen2, MISTRAL->QWEN2 +class Qwen2ForQuestionAnswering(Qwen2PreTrainedModel): + base_model_prefix = "model" + + # Copied from models.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Qwen2 + def __init__(self, config): + super().__init__(config) + self.model = Qwen2Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`).
Positions outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, the split adds a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/qwen2_moe/__init__.py b/src/transformers/models/qwen2_moe/__init__.py index e2b73ba2d1..9520141ea8 100644 ---
a/src/transformers/models/qwen2_moe/__init__.py +++ b/src/transformers/models/qwen2_moe/__init__.py @@ -33,6 +33,7 @@ except OptionalDependencyNotAvailable: else: _import_structure["modeling_qwen2_moe"] = [ "Qwen2MoeForCausalLM", + "Qwen2MoeForQuestionAnswering", "Qwen2MoeModel", "Qwen2MoePreTrainedModel", "Qwen2MoeForSequenceClassification", @@ -51,6 +52,7 @@ if TYPE_CHECKING: else: from .modeling_qwen2_moe import ( Qwen2MoeForCausalLM, + Qwen2MoeForQuestionAnswering, Qwen2MoeForSequenceClassification, Qwen2MoeForTokenClassification, Qwen2MoeModel, diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 2ab13b7227..1a5f6e2ff2 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -35,6 +35,7 @@ from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( MoeCausalLMOutputWithPast, MoeModelOutputWithPast, + QuestionAnsweringModelOutput, SequenceClassifierOutputWithPast, TokenClassifierOutput, ) @@ -1713,3 +1714,103 @@ class Qwen2MoeForTokenClassification(Qwen2MoePreTrainedModel): hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +@add_start_docstrings( + """ +The Qwen2MoE Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). 
+ """, + QWEN2MOE_START_DOCSTRING, +) +# Copied from transformers.models.mistral.modeling_mistral.MistralForQuestionAnswering with Mistral->Qwen2Moe, MISTRAL->QWEN2MOE +class Qwen2MoeForQuestionAnswering(Qwen2MoePreTrainedModel): + base_model_prefix = "model" + + # Copied from models.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Qwen2Moe + def __init__(self, config): + super().__init__(config) + self.model = Qwen2MoeModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, the split adds a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py
b/src/transformers/utils/dummy_pt_objects.py index 048de1cc8a..4ca25bc791 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -5920,6 +5920,13 @@ class MistralForCausalLM(metaclass=DummyObject): requires_backends(self, ["torch"]) +class MistralForQuestionAnswering(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class MistralForSequenceClassification(metaclass=DummyObject): _backends = ["torch"] @@ -5955,6 +5962,13 @@ class MixtralForCausalLM(metaclass=DummyObject): requires_backends(self, ["torch"]) +class MixtralForQuestionAnswering(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class MixtralForSequenceClassification(metaclass=DummyObject): _backends = ["torch"] @@ -7406,6 +7420,13 @@ class Qwen2ForCausalLM(metaclass=DummyObject): requires_backends(self, ["torch"]) +class Qwen2ForQuestionAnswering(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class Qwen2ForSequenceClassification(metaclass=DummyObject): _backends = ["torch"] @@ -7462,6 +7483,13 @@ class Qwen2MoeForCausalLM(metaclass=DummyObject): requires_backends(self, ["torch"]) +class Qwen2MoeForQuestionAnswering(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class Qwen2MoeForSequenceClassification(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index c24436d4b8..ff7f1e87bc 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -47,6 +47,7 @@ if is_torch_available(): from transformers import ( MistralForCausalLM, + MistralForQuestionAnswering, 
MistralForSequenceClassification, MistralForTokenClassification, MistralModel, @@ -291,7 +292,13 @@ class MistralModelTester: @require_torch class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = ( - (MistralModel, MistralForCausalLM, MistralForSequenceClassification, MistralForTokenClassification) + ( + MistralModel, + MistralForCausalLM, + MistralForSequenceClassification, + MistralForTokenClassification, + MistralForQuestionAnswering, + ) if is_torch_available() else () ) @@ -303,6 +310,7 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi "token-classification": MistralForTokenClassification, "text-generation": MistralForCausalLM, "zero-shot": MistralForSequenceClassification, + "question-answering": MistralForQuestionAnswering, } if is_torch_available() else {} diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py index a2655bb773..0e6b2a999e 100644 --- a/tests/models/mixtral/test_modeling_mixtral.py +++ b/tests/models/mixtral/test_modeling_mixtral.py @@ -41,6 +41,7 @@ if is_torch_available(): from transformers import ( MixtralForCausalLM, + MixtralForQuestionAnswering, MixtralForSequenceClassification, MixtralForTokenClassification, MixtralModel, @@ -291,7 +292,13 @@ class MixtralModelTester: # Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Mixtral class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = ( - (MixtralModel, MixtralForCausalLM, MixtralForSequenceClassification, MixtralForTokenClassification) + ( + MixtralModel, + MixtralForCausalLM, + MixtralForSequenceClassification, + MixtralForTokenClassification, + MixtralForQuestionAnswering, + ) if is_torch_available() else () ) @@ -303,6 +310,7 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi 
"token-classification": MixtralForTokenClassification, "text-generation": MixtralForCausalLM, "zero-shot": MixtralForSequenceClassification, + "question-answering": MixtralForQuestionAnswering, } if is_torch_available() else {} diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py index debcf42ab3..5e5c42d4c5 100644 --- a/tests/models/qwen2/test_modeling_qwen2.py +++ b/tests/models/qwen2/test_modeling_qwen2.py @@ -43,6 +43,7 @@ if is_torch_available(): from transformers import ( Qwen2ForCausalLM, + Qwen2ForQuestionAnswering, Qwen2ForSequenceClassification, Qwen2ForTokenClassification, Qwen2Model, @@ -300,7 +301,13 @@ class Qwen2ModelTester: # Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen2 class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = ( - (Qwen2Model, Qwen2ForCausalLM, Qwen2ForSequenceClassification, Qwen2ForTokenClassification) + ( + Qwen2Model, + Qwen2ForCausalLM, + Qwen2ForSequenceClassification, + Qwen2ForTokenClassification, + Qwen2ForQuestionAnswering, + ) if is_torch_available() else () ) @@ -312,6 +319,7 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi "token-classification": Qwen2ForTokenClassification, "text-generation": Qwen2ForCausalLM, "zero-shot": Qwen2ForSequenceClassification, + "question-answering": Qwen2ForQuestionAnswering, } if is_torch_available() else {} diff --git a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py index 60df825c9b..d7b17b740f 100644 --- a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py +++ b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py @@ -43,6 +43,7 @@ if is_torch_available(): from transformers import ( Qwen2MoeForCausalLM, + Qwen2MoeForQuestionAnswering, Qwen2MoeForSequenceClassification, Qwen2MoeForTokenClassification, Qwen2MoeModel, @@ -327,7 +328,13 @@ class 
Qwen2MoeModelTester: # Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen2Moe class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = ( - (Qwen2MoeModel, Qwen2MoeForCausalLM, Qwen2MoeForSequenceClassification, Qwen2MoeForTokenClassification) + ( + Qwen2MoeModel, + Qwen2MoeForCausalLM, + Qwen2MoeForSequenceClassification, + Qwen2MoeForTokenClassification, + Qwen2MoeForQuestionAnswering, + ) if is_torch_available() else () ) @@ -339,6 +346,7 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM "token-classification": Qwen2MoeForTokenClassification, "text-generation": Qwen2MoeForCausalLM, "zero-shot": Qwen2MoeForSequenceClassification, + "question-answering": Qwen2MoeForQuestionAnswering, } if is_torch_available() else {}