Mistral-related models for QnA (#34045)
* mistral qna start * mixtral qna * oops * qwen2 qna * qwen2moe qna * add missing input embed methods * add copied to all methods, can't directly from llama due to the prefix * make top level copied from
This commit is contained in:
@@ -208,6 +208,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
|
|||||||
[[autodoc]] MistralForTokenClassification
|
[[autodoc]] MistralForTokenClassification
|
||||||
- forward
|
- forward
|
||||||
|
|
||||||
|
## MistralForQuestionAnswering
|
||||||
|
|
||||||
|
[[autodoc]] MistralForQuestionAnswering
|
||||||
|
- forward
|
||||||
|
|
||||||
## FlaxMistralModel
|
## FlaxMistralModel
|
||||||
|
|
||||||
[[autodoc]] FlaxMistralModel
|
[[autodoc]] FlaxMistralModel
|
||||||
|
|||||||
@@ -209,3 +209,7 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
|
|||||||
|
|
||||||
[[autodoc]] MixtralForTokenClassification
|
[[autodoc]] MixtralForTokenClassification
|
||||||
- forward
|
- forward
|
||||||
|
|
||||||
|
## MixtralForQuestionAnswering
|
||||||
|
[[autodoc]] MixtralForQuestionAnswering
|
||||||
|
- forward
|
||||||
|
|||||||
@@ -85,3 +85,8 @@ In the following, we demonstrate how to use `Qwen2-7B-Instruct` for the inferenc
|
|||||||
|
|
||||||
[[autodoc]] Qwen2ForTokenClassification
|
[[autodoc]] Qwen2ForTokenClassification
|
||||||
- forward
|
- forward
|
||||||
|
|
||||||
|
## Qwen2ForQuestionAnswering
|
||||||
|
|
||||||
|
[[autodoc]] Qwen2ForQuestionAnswering
|
||||||
|
- forward
|
||||||
|
|||||||
@@ -80,3 +80,8 @@ In the following, we demonstrate how to use `Qwen1.5-MoE-A2.7B-Chat` for the inf
|
|||||||
|
|
||||||
[[autodoc]] Qwen2MoeForTokenClassification
|
[[autodoc]] Qwen2MoeForTokenClassification
|
||||||
- forward
|
- forward
|
||||||
|
|
||||||
|
## Qwen2MoeForQuestionAnswering
|
||||||
|
|
||||||
|
[[autodoc]] Qwen2MoeForQuestionAnswering
|
||||||
|
- forward
|
||||||
|
|||||||
@@ -2709,6 +2709,7 @@ else:
|
|||||||
_import_structure["models.mistral"].extend(
|
_import_structure["models.mistral"].extend(
|
||||||
[
|
[
|
||||||
"MistralForCausalLM",
|
"MistralForCausalLM",
|
||||||
|
"MistralForQuestionAnswering",
|
||||||
"MistralForSequenceClassification",
|
"MistralForSequenceClassification",
|
||||||
"MistralForTokenClassification",
|
"MistralForTokenClassification",
|
||||||
"MistralModel",
|
"MistralModel",
|
||||||
@@ -2718,6 +2719,7 @@ else:
|
|||||||
_import_structure["models.mixtral"].extend(
|
_import_structure["models.mixtral"].extend(
|
||||||
[
|
[
|
||||||
"MixtralForCausalLM",
|
"MixtralForCausalLM",
|
||||||
|
"MixtralForQuestionAnswering",
|
||||||
"MixtralForSequenceClassification",
|
"MixtralForSequenceClassification",
|
||||||
"MixtralForTokenClassification",
|
"MixtralForTokenClassification",
|
||||||
"MixtralModel",
|
"MixtralModel",
|
||||||
@@ -3094,6 +3096,7 @@ else:
|
|||||||
_import_structure["models.qwen2"].extend(
|
_import_structure["models.qwen2"].extend(
|
||||||
[
|
[
|
||||||
"Qwen2ForCausalLM",
|
"Qwen2ForCausalLM",
|
||||||
|
"Qwen2ForQuestionAnswering",
|
||||||
"Qwen2ForSequenceClassification",
|
"Qwen2ForSequenceClassification",
|
||||||
"Qwen2ForTokenClassification",
|
"Qwen2ForTokenClassification",
|
||||||
"Qwen2Model",
|
"Qwen2Model",
|
||||||
@@ -3110,6 +3113,7 @@ else:
|
|||||||
_import_structure["models.qwen2_moe"].extend(
|
_import_structure["models.qwen2_moe"].extend(
|
||||||
[
|
[
|
||||||
"Qwen2MoeForCausalLM",
|
"Qwen2MoeForCausalLM",
|
||||||
|
"Qwen2MoeForQuestionAnswering",
|
||||||
"Qwen2MoeForSequenceClassification",
|
"Qwen2MoeForSequenceClassification",
|
||||||
"Qwen2MoeForTokenClassification",
|
"Qwen2MoeForTokenClassification",
|
||||||
"Qwen2MoeModel",
|
"Qwen2MoeModel",
|
||||||
@@ -7323,6 +7327,7 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
from .models.mistral import (
|
from .models.mistral import (
|
||||||
MistralForCausalLM,
|
MistralForCausalLM,
|
||||||
|
MistralForQuestionAnswering,
|
||||||
MistralForSequenceClassification,
|
MistralForSequenceClassification,
|
||||||
MistralForTokenClassification,
|
MistralForTokenClassification,
|
||||||
MistralModel,
|
MistralModel,
|
||||||
@@ -7330,6 +7335,7 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
from .models.mixtral import (
|
from .models.mixtral import (
|
||||||
MixtralForCausalLM,
|
MixtralForCausalLM,
|
||||||
|
MixtralForQuestionAnswering,
|
||||||
MixtralForSequenceClassification,
|
MixtralForSequenceClassification,
|
||||||
MixtralForTokenClassification,
|
MixtralForTokenClassification,
|
||||||
MixtralModel,
|
MixtralModel,
|
||||||
@@ -7625,6 +7631,7 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
from .models.qwen2 import (
|
from .models.qwen2 import (
|
||||||
Qwen2ForCausalLM,
|
Qwen2ForCausalLM,
|
||||||
|
Qwen2ForQuestionAnswering,
|
||||||
Qwen2ForSequenceClassification,
|
Qwen2ForSequenceClassification,
|
||||||
Qwen2ForTokenClassification,
|
Qwen2ForTokenClassification,
|
||||||
Qwen2Model,
|
Qwen2Model,
|
||||||
@@ -7637,6 +7644,7 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
from .models.qwen2_moe import (
|
from .models.qwen2_moe import (
|
||||||
Qwen2MoeForCausalLM,
|
Qwen2MoeForCausalLM,
|
||||||
|
Qwen2MoeForQuestionAnswering,
|
||||||
Qwen2MoeForSequenceClassification,
|
Qwen2MoeForSequenceClassification,
|
||||||
Qwen2MoeForTokenClassification,
|
Qwen2MoeForTokenClassification,
|
||||||
Qwen2MoeModel,
|
Qwen2MoeModel,
|
||||||
|
|||||||
@@ -1046,6 +1046,8 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
|
|||||||
("mbart", "MBartForQuestionAnswering"),
|
("mbart", "MBartForQuestionAnswering"),
|
||||||
("mega", "MegaForQuestionAnswering"),
|
("mega", "MegaForQuestionAnswering"),
|
||||||
("megatron-bert", "MegatronBertForQuestionAnswering"),
|
("megatron-bert", "MegatronBertForQuestionAnswering"),
|
||||||
|
("mistral", "MistralForQuestionAnswering"),
|
||||||
|
("mixtral", "MixtralForQuestionAnswering"),
|
||||||
("mobilebert", "MobileBertForQuestionAnswering"),
|
("mobilebert", "MobileBertForQuestionAnswering"),
|
||||||
("mpnet", "MPNetForQuestionAnswering"),
|
("mpnet", "MPNetForQuestionAnswering"),
|
||||||
("mpt", "MptForQuestionAnswering"),
|
("mpt", "MptForQuestionAnswering"),
|
||||||
@@ -1057,6 +1059,8 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
|
|||||||
("nystromformer", "NystromformerForQuestionAnswering"),
|
("nystromformer", "NystromformerForQuestionAnswering"),
|
||||||
("opt", "OPTForQuestionAnswering"),
|
("opt", "OPTForQuestionAnswering"),
|
||||||
("qdqbert", "QDQBertForQuestionAnswering"),
|
("qdqbert", "QDQBertForQuestionAnswering"),
|
||||||
|
("qwen2", "Qwen2ForQuestionAnswering"),
|
||||||
|
("qwen2_moe", "Qwen2MoeForQuestionAnswering"),
|
||||||
("reformer", "ReformerForQuestionAnswering"),
|
("reformer", "ReformerForQuestionAnswering"),
|
||||||
("rembert", "RemBertForQuestionAnswering"),
|
("rembert", "RemBertForQuestionAnswering"),
|
||||||
("roberta", "RobertaForQuestionAnswering"),
|
("roberta", "RobertaForQuestionAnswering"),
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ except OptionalDependencyNotAvailable:
|
|||||||
else:
|
else:
|
||||||
_import_structure["modeling_mistral"] = [
|
_import_structure["modeling_mistral"] = [
|
||||||
"MistralForCausalLM",
|
"MistralForCausalLM",
|
||||||
|
"MistralForQuestionAnswering",
|
||||||
"MistralModel",
|
"MistralModel",
|
||||||
"MistralPreTrainedModel",
|
"MistralPreTrainedModel",
|
||||||
"MistralForSequenceClassification",
|
"MistralForSequenceClassification",
|
||||||
@@ -78,6 +79,7 @@ if TYPE_CHECKING:
|
|||||||
else:
|
else:
|
||||||
from .modeling_mistral import (
|
from .modeling_mistral import (
|
||||||
MistralForCausalLM,
|
MistralForCausalLM,
|
||||||
|
MistralForQuestionAnswering,
|
||||||
MistralForSequenceClassification,
|
MistralForSequenceClassification,
|
||||||
MistralForTokenClassification,
|
MistralForTokenClassification,
|
||||||
MistralModel,
|
MistralModel,
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ from ...modeling_attn_mask_utils import AttentionMaskConverter
|
|||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithPast,
|
BaseModelOutputWithPast,
|
||||||
CausalLMOutputWithPast,
|
CausalLMOutputWithPast,
|
||||||
|
QuestionAnsweringModelOutput,
|
||||||
SequenceClassifierOutputWithPast,
|
SequenceClassifierOutputWithPast,
|
||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
@@ -1408,3 +1409,103 @@ class MistralForTokenClassification(MistralPreTrainedModel):
|
|||||||
hidden_states=outputs.hidden_states,
|
hidden_states=outputs.hidden_states,
|
||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
    """
    The Mistral Model transformer with a span classification head on top for extractive question-answering tasks like
    SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    MISTRAL_START_DOCSTRING,
)
# Adapted from transformers.models.llama.modeling_llama.LlamaForQuestionAnswering with Llama->Mistral,LLAMA->MISTRAL,transformer->model
# (no longer a verbatim copy: span labels are always moved to the logits' device before the loss is computed).
class MistralForQuestionAnswering(MistralPreTrainedModel):
    # Attribute name under which the backbone is stored on this head model.
    base_model_prefix = "model"

    # Same construction as BloomForQuestionAnswering.__init__ with Bloom->Mistral and the backbone stored as `model`.
    def __init__(self, config):
        super().__init__(config)
        self.model = MistralModel(config)
        # One linear layer producing two values per token: the span-start logit and the span-end logit.
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module of the wrapped base model."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the wrapped base model with `value`."""
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # Project every token to two scores, then separate them into start and end logits.
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, the split adds a trailing dimension to the labels; drop it.
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)
            # Always place the labels on the logits' device. Previously this only happened when the labels
            # carried the extra dimension, so 1-D labels on another device reached the loss unmoved.
            start_positions = start_positions.to(start_logits.device)
            end_positions = end_positions.to(end_logits.device)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            # NOTE(review): `outputs[2:]` presumes the entry at index 1 is not needed here — confirm against the
            # tuple layout the base model returns when `return_dict=False`.
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ except OptionalDependencyNotAvailable:
|
|||||||
else:
|
else:
|
||||||
_import_structure["modeling_mixtral"] = [
|
_import_structure["modeling_mixtral"] = [
|
||||||
"MixtralForCausalLM",
|
"MixtralForCausalLM",
|
||||||
|
"MixtralForQuestionAnswering",
|
||||||
"MixtralModel",
|
"MixtralModel",
|
||||||
"MixtralPreTrainedModel",
|
"MixtralPreTrainedModel",
|
||||||
"MixtralForSequenceClassification",
|
"MixtralForSequenceClassification",
|
||||||
@@ -51,6 +52,7 @@ if TYPE_CHECKING:
|
|||||||
else:
|
else:
|
||||||
from .modeling_mixtral import (
|
from .modeling_mixtral import (
|
||||||
MixtralForCausalLM,
|
MixtralForCausalLM,
|
||||||
|
MixtralForQuestionAnswering,
|
||||||
MixtralForSequenceClassification,
|
MixtralForSequenceClassification,
|
||||||
MixtralForTokenClassification,
|
MixtralForTokenClassification,
|
||||||
MixtralModel,
|
MixtralModel,
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_caus
|
|||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
MoeCausalLMOutputWithPast,
|
MoeCausalLMOutputWithPast,
|
||||||
MoeModelOutputWithPast,
|
MoeModelOutputWithPast,
|
||||||
|
QuestionAnsweringModelOutput,
|
||||||
SequenceClassifierOutputWithPast,
|
SequenceClassifierOutputWithPast,
|
||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
@@ -1644,3 +1645,103 @@ class MixtralForTokenClassification(MixtralPreTrainedModel):
|
|||||||
hidden_states=outputs.hidden_states,
|
hidden_states=outputs.hidden_states,
|
||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
    """
    The Mixtral Model transformer with a span classification head on top for extractive question-answering tasks like
    SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    MIXTRAL_START_DOCSTRING,
)
# Adapted from transformers.models.mistral.modeling_mistral.MistralForQuestionAnswering with Mistral->Mixtral, MISTRAL->MIXTRAL
# (no longer a verbatim copy: span labels are always moved to the logits' device before the loss is computed).
class MixtralForQuestionAnswering(MixtralPreTrainedModel):
    # Attribute name under which the backbone is stored on this head model.
    base_model_prefix = "model"

    # Same construction as BloomForQuestionAnswering.__init__ with Bloom->Mixtral and the backbone stored as `model`.
    def __init__(self, config):
        super().__init__(config)
        self.model = MixtralModel(config)
        # One linear layer producing two values per token: the span-start logit and the span-end logit.
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module of the wrapped base model."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the wrapped base model with `value`."""
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # Project every token to two scores, then separate them into start and end logits.
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, the split adds a trailing dimension to the labels; drop it.
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)
            # Always place the labels on the logits' device. Previously this only happened when the labels
            # carried the extra dimension, so 1-D labels on another device reached the loss unmoved.
            start_positions = start_positions.to(start_logits.device)
            end_positions = end_positions.to(end_logits.device)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            # NOTE(review): `outputs[2:]` presumes the entry at index 1 is not needed here — confirm against the
            # tuple layout the base model returns when `return_dict=False`.
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ except OptionalDependencyNotAvailable:
|
|||||||
else:
|
else:
|
||||||
_import_structure["modeling_qwen2"] = [
|
_import_structure["modeling_qwen2"] = [
|
||||||
"Qwen2ForCausalLM",
|
"Qwen2ForCausalLM",
|
||||||
|
"Qwen2ForQuestionAnswering",
|
||||||
"Qwen2Model",
|
"Qwen2Model",
|
||||||
"Qwen2PreTrainedModel",
|
"Qwen2PreTrainedModel",
|
||||||
"Qwen2ForSequenceClassification",
|
"Qwen2ForSequenceClassification",
|
||||||
@@ -69,6 +70,7 @@ if TYPE_CHECKING:
|
|||||||
else:
|
else:
|
||||||
from .modeling_qwen2 import (
|
from .modeling_qwen2 import (
|
||||||
Qwen2ForCausalLM,
|
Qwen2ForCausalLM,
|
||||||
|
Qwen2ForQuestionAnswering,
|
||||||
Qwen2ForSequenceClassification,
|
Qwen2ForSequenceClassification,
|
||||||
Qwen2ForTokenClassification,
|
Qwen2ForTokenClassification,
|
||||||
Qwen2Model,
|
Qwen2Model,
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ from ...modeling_attn_mask_utils import AttentionMaskConverter
|
|||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithPast,
|
BaseModelOutputWithPast,
|
||||||
CausalLMOutputWithPast,
|
CausalLMOutputWithPast,
|
||||||
|
QuestionAnsweringModelOutput,
|
||||||
SequenceClassifierOutputWithPast,
|
SequenceClassifierOutputWithPast,
|
||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
@@ -1508,3 +1509,103 @@ class Qwen2ForTokenClassification(Qwen2PreTrainedModel):
|
|||||||
hidden_states=outputs.hidden_states,
|
hidden_states=outputs.hidden_states,
|
||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
    """
    The Qwen2 Model transformer with a span classification head on top for extractive question-answering tasks like
    SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    QWEN2_START_DOCSTRING,
)
# Adapted from transformers.models.mistral.modeling_mistral.MistralForQuestionAnswering with Mistral->Qwen2, MISTRAL->QWEN2
# (no longer a verbatim copy: span labels are always moved to the logits' device before the loss is computed).
class Qwen2ForQuestionAnswering(Qwen2PreTrainedModel):
    # Attribute name under which the backbone is stored on this head model.
    base_model_prefix = "model"

    # Same construction as BloomForQuestionAnswering.__init__ with Bloom->Qwen2 and the backbone stored as `model`.
    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen2Model(config)
        # One linear layer producing two values per token: the span-start logit and the span-end logit.
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module of the wrapped base model."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the wrapped base model with `value`."""
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # Project every token to two scores, then separate them into start and end logits.
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, the split adds a trailing dimension to the labels; drop it.
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)
            # Always place the labels on the logits' device. Previously this only happened when the labels
            # carried the extra dimension, so 1-D labels on another device reached the loss unmoved.
            start_positions = start_positions.to(start_logits.device)
            end_positions = end_positions.to(end_logits.device)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            # NOTE(review): `outputs[2:]` presumes the entry at index 1 is not needed here — confirm against the
            # tuple layout the base model returns when `return_dict=False`.
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ except OptionalDependencyNotAvailable:
|
|||||||
else:
|
else:
|
||||||
_import_structure["modeling_qwen2_moe"] = [
|
_import_structure["modeling_qwen2_moe"] = [
|
||||||
"Qwen2MoeForCausalLM",
|
"Qwen2MoeForCausalLM",
|
||||||
|
"Qwen2MoeForQuestionAnswering",
|
||||||
"Qwen2MoeModel",
|
"Qwen2MoeModel",
|
||||||
"Qwen2MoePreTrainedModel",
|
"Qwen2MoePreTrainedModel",
|
||||||
"Qwen2MoeForSequenceClassification",
|
"Qwen2MoeForSequenceClassification",
|
||||||
@@ -51,6 +52,7 @@ if TYPE_CHECKING:
|
|||||||
else:
|
else:
|
||||||
from .modeling_qwen2_moe import (
|
from .modeling_qwen2_moe import (
|
||||||
Qwen2MoeForCausalLM,
|
Qwen2MoeForCausalLM,
|
||||||
|
Qwen2MoeForQuestionAnswering,
|
||||||
Qwen2MoeForSequenceClassification,
|
Qwen2MoeForSequenceClassification,
|
||||||
Qwen2MoeForTokenClassification,
|
Qwen2MoeForTokenClassification,
|
||||||
Qwen2MoeModel,
|
Qwen2MoeModel,
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ from ...modeling_attn_mask_utils import AttentionMaskConverter
|
|||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
MoeCausalLMOutputWithPast,
|
MoeCausalLMOutputWithPast,
|
||||||
MoeModelOutputWithPast,
|
MoeModelOutputWithPast,
|
||||||
|
QuestionAnsweringModelOutput,
|
||||||
SequenceClassifierOutputWithPast,
|
SequenceClassifierOutputWithPast,
|
||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
@@ -1713,3 +1714,103 @@ class Qwen2MoeForTokenClassification(Qwen2MoePreTrainedModel):
|
|||||||
hidden_states=outputs.hidden_states,
|
hidden_states=outputs.hidden_states,
|
||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"""
|
||||||
|
The Qwen2MoE Model transformer with a span classification head on top for extractive question-answering tasks like
|
||||||
|
SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
||||||
|
""",
|
||||||
|
QWEN2MOE_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
# Copied from transformers.models.mistral.modeling_mistral.MistralForQuestionAnswering with Mistral->Qwen2Moe, MISTRAL->QWEN2MOE
|
||||||
|
class Qwen2MoeForQuestionAnswering(Qwen2MoePreTrainedModel):
|
||||||
|
base_model_prefix = "model"
|
||||||
|
|
||||||
|
# Copied from models.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Qwen2Moe
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.model = Qwen2MoeModel(config)
|
||||||
|
self.qa_outputs = nn.Linear(config.hidden_size, 2)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.model.embed_tokens
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.model.embed_tokens = value
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
start_positions: Optional[torch.LongTensor] = None,
|
||||||
|
end_positions: Optional[torch.LongTensor] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
) -> Union[Tuple, QuestionAnsweringModelOutput]:
|
||||||
|
r"""
|
||||||
|
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||||
|
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
||||||
|
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
|
||||||
|
are not taken into account for computing the loss.
|
||||||
|
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||||
|
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
||||||
|
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
|
||||||
|
are not taken into account for computing the loss.
|
||||||
|
"""
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
sequence_output = outputs[0]
|
||||||
|
|
||||||
|
logits = self.qa_outputs(sequence_output)
|
||||||
|
start_logits, end_logits = logits.split(1, dim=-1)
|
||||||
|
start_logits = start_logits.squeeze(-1).contiguous()
|
||||||
|
end_logits = end_logits.squeeze(-1).contiguous()
|
||||||
|
|
||||||
|
total_loss = None
|
||||||
|
if start_positions is not None and end_positions is not None:
|
||||||
|
# If we are on multi-GPU, the split adds a dimension
|
||||||
|
if len(start_positions.size()) > 1:
|
||||||
|
start_positions = start_positions.squeeze(-1).to(start_logits.device)
|
||||||
|
if len(end_positions.size()) > 1:
|
||||||
|
end_positions = end_positions.squeeze(-1).to(end_logits.device)
|
||||||
|
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||||
|
ignored_index = start_logits.size(1)
|
||||||
|
start_positions = start_positions.clamp(0, ignored_index)
|
||||||
|
end_positions = end_positions.clamp(0, ignored_index)
|
||||||
|
|
||||||
|
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||||
|
start_loss = loss_fct(start_logits, start_positions)
|
||||||
|
end_loss = loss_fct(end_logits, end_positions)
|
||||||
|
total_loss = (start_loss + end_loss) / 2
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (start_logits, end_logits) + outputs[2:]
|
||||||
|
return ((total_loss,) + output) if total_loss is not None else output
|
||||||
|
|
||||||
|
return QuestionAnsweringModelOutput(
|
||||||
|
loss=total_loss,
|
||||||
|
start_logits=start_logits,
|
||||||
|
end_logits=end_logits,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|||||||
@@ -5920,6 +5920,13 @@ class MistralForCausalLM(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class MistralForQuestionAnswering(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class MistralForSequenceClassification(metaclass=DummyObject):
|
class MistralForSequenceClassification(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
@@ -5955,6 +5962,13 @@ class MixtralForCausalLM(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class MixtralForQuestionAnswering(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class MixtralForSequenceClassification(metaclass=DummyObject):
|
class MixtralForSequenceClassification(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
@@ -7406,6 +7420,13 @@ class Qwen2ForCausalLM(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2ForQuestionAnswering(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class Qwen2ForSequenceClassification(metaclass=DummyObject):
|
class Qwen2ForSequenceClassification(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
@@ -7462,6 +7483,13 @@ class Qwen2MoeForCausalLM(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2MoeForQuestionAnswering(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class Qwen2MoeForSequenceClassification(metaclass=DummyObject):
|
class Qwen2MoeForSequenceClassification(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ if is_torch_available():
|
|||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
MistralForCausalLM,
|
MistralForCausalLM,
|
||||||
|
MistralForQuestionAnswering,
|
||||||
MistralForSequenceClassification,
|
MistralForSequenceClassification,
|
||||||
MistralForTokenClassification,
|
MistralForTokenClassification,
|
||||||
MistralModel,
|
MistralModel,
|
||||||
@@ -291,7 +292,13 @@ class MistralModelTester:
|
|||||||
@require_torch
|
@require_torch
|
||||||
class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(MistralModel, MistralForCausalLM, MistralForSequenceClassification, MistralForTokenClassification)
|
(
|
||||||
|
MistralModel,
|
||||||
|
MistralForCausalLM,
|
||||||
|
MistralForSequenceClassification,
|
||||||
|
MistralForTokenClassification,
|
||||||
|
MistralForQuestionAnswering,
|
||||||
|
)
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
@@ -303,6 +310,7 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
"token-classification": MistralForTokenClassification,
|
"token-classification": MistralForTokenClassification,
|
||||||
"text-generation": MistralForCausalLM,
|
"text-generation": MistralForCausalLM,
|
||||||
"zero-shot": MistralForSequenceClassification,
|
"zero-shot": MistralForSequenceClassification,
|
||||||
|
"question-answering": MistralForQuestionAnswering,
|
||||||
}
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ if is_torch_available():
|
|||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
MixtralForCausalLM,
|
MixtralForCausalLM,
|
||||||
|
MixtralForQuestionAnswering,
|
||||||
MixtralForSequenceClassification,
|
MixtralForSequenceClassification,
|
||||||
MixtralForTokenClassification,
|
MixtralForTokenClassification,
|
||||||
MixtralModel,
|
MixtralModel,
|
||||||
@@ -291,7 +292,13 @@ class MixtralModelTester:
|
|||||||
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Mixtral
|
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Mixtral
|
||||||
class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(MixtralModel, MixtralForCausalLM, MixtralForSequenceClassification, MixtralForTokenClassification)
|
(
|
||||||
|
MixtralModel,
|
||||||
|
MixtralForCausalLM,
|
||||||
|
MixtralForSequenceClassification,
|
||||||
|
MixtralForTokenClassification,
|
||||||
|
MixtralForQuestionAnswering,
|
||||||
|
)
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
@@ -303,6 +310,7 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
"token-classification": MixtralForTokenClassification,
|
"token-classification": MixtralForTokenClassification,
|
||||||
"text-generation": MixtralForCausalLM,
|
"text-generation": MixtralForCausalLM,
|
||||||
"zero-shot": MixtralForSequenceClassification,
|
"zero-shot": MixtralForSequenceClassification,
|
||||||
|
"question-answering": MixtralForQuestionAnswering,
|
||||||
}
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ if is_torch_available():
|
|||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
Qwen2ForCausalLM,
|
Qwen2ForCausalLM,
|
||||||
|
Qwen2ForQuestionAnswering,
|
||||||
Qwen2ForSequenceClassification,
|
Qwen2ForSequenceClassification,
|
||||||
Qwen2ForTokenClassification,
|
Qwen2ForTokenClassification,
|
||||||
Qwen2Model,
|
Qwen2Model,
|
||||||
@@ -300,7 +301,13 @@ class Qwen2ModelTester:
|
|||||||
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen2
|
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen2
|
||||||
class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(Qwen2Model, Qwen2ForCausalLM, Qwen2ForSequenceClassification, Qwen2ForTokenClassification)
|
(
|
||||||
|
Qwen2Model,
|
||||||
|
Qwen2ForCausalLM,
|
||||||
|
Qwen2ForSequenceClassification,
|
||||||
|
Qwen2ForTokenClassification,
|
||||||
|
Qwen2ForQuestionAnswering,
|
||||||
|
)
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
@@ -312,6 +319,7 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
"token-classification": Qwen2ForTokenClassification,
|
"token-classification": Qwen2ForTokenClassification,
|
||||||
"text-generation": Qwen2ForCausalLM,
|
"text-generation": Qwen2ForCausalLM,
|
||||||
"zero-shot": Qwen2ForSequenceClassification,
|
"zero-shot": Qwen2ForSequenceClassification,
|
||||||
|
"question-answering": Qwen2ForQuestionAnswering,
|
||||||
}
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ if is_torch_available():
|
|||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
Qwen2MoeForCausalLM,
|
Qwen2MoeForCausalLM,
|
||||||
|
Qwen2MoeForQuestionAnswering,
|
||||||
Qwen2MoeForSequenceClassification,
|
Qwen2MoeForSequenceClassification,
|
||||||
Qwen2MoeForTokenClassification,
|
Qwen2MoeForTokenClassification,
|
||||||
Qwen2MoeModel,
|
Qwen2MoeModel,
|
||||||
@@ -327,7 +328,13 @@ class Qwen2MoeModelTester:
|
|||||||
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen2Moe
|
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen2Moe
|
||||||
class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(Qwen2MoeModel, Qwen2MoeForCausalLM, Qwen2MoeForSequenceClassification, Qwen2MoeForTokenClassification)
|
(
|
||||||
|
Qwen2MoeModel,
|
||||||
|
Qwen2MoeForCausalLM,
|
||||||
|
Qwen2MoeForSequenceClassification,
|
||||||
|
Qwen2MoeForTokenClassification,
|
||||||
|
Qwen2MoeForQuestionAnswering,
|
||||||
|
)
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
@@ -339,6 +346,7 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
|||||||
"token-classification": Qwen2MoeForTokenClassification,
|
"token-classification": Qwen2MoeForTokenClassification,
|
||||||
"text-generation": Qwen2MoeForCausalLM,
|
"text-generation": Qwen2MoeForCausalLM,
|
||||||
"zero-shot": Qwen2MoeForSequenceClassification,
|
"zero-shot": Qwen2MoeForSequenceClassification,
|
||||||
|
"question-answering": Qwen2MoeForQuestionAnswering,
|
||||||
}
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
|
|||||||
Reference in New Issue
Block a user