Add TokenClassification for Mistral, Mixtral and Qwen2 (#29878)

* Add MistralForTokenClassification

* Add tests and docs

* Add token classification for Mixtral and Qwen2

* Save llma for token classification draft

* Add token classification support for Llama, Gemma, Persimmon, StableLm and StarCoder2

* Formatting

* Add token classification support for Qwen2Moe model

* Add dropout layer to each ForTokenClassification model

* Add copied from in tests

* Update src/transformers/models/llama/modeling_llama.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Propagate suggested changes

* Style

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
Joseph Enguehard
2024-05-20 09:06:57 +01:00
committed by GitHub
parent 481a957814
commit 07bf2dff78
39 changed files with 1174 additions and 19 deletions

View File

@@ -60,6 +60,11 @@ This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ), [
[[autodoc]] GemmaForSequenceClassification [[autodoc]] GemmaForSequenceClassification
- forward - forward
## GemmaForTokenClassification
[[autodoc]] GemmaForTokenClassification
- forward
## FlaxGemmaModel ## FlaxGemmaModel
[[autodoc]] FlaxGemmaModel [[autodoc]] FlaxGemmaModel

View File

@@ -121,6 +121,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
[[autodoc]] LlamaForQuestionAnswering [[autodoc]] LlamaForQuestionAnswering
- forward - forward
## LlamaForTokenClassification
[[autodoc]] LlamaForTokenClassification
- forward
## FlaxLlamaModel ## FlaxLlamaModel
[[autodoc]] FlaxLlamaModel [[autodoc]] FlaxLlamaModel

View File

@@ -203,6 +203,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
[[autodoc]] MistralForSequenceClassification [[autodoc]] MistralForSequenceClassification
- forward - forward
## MistralForTokenClassification
[[autodoc]] MistralForTokenClassification
- forward
## FlaxMistralModel ## FlaxMistralModel
[[autodoc]] FlaxMistralModel [[autodoc]] FlaxMistralModel

View File

@@ -204,3 +204,8 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
[[autodoc]] MixtralForSequenceClassification [[autodoc]] MixtralForSequenceClassification
- forward - forward
## MixtralForTokenClassification
[[autodoc]] MixtralForTokenClassification
- forward

View File

@@ -96,3 +96,8 @@ The `LlamaTokenizer` is used as it is a standard wrapper around sentencepiece. T
[[autodoc]] PersimmonForSequenceClassification [[autodoc]] PersimmonForSequenceClassification
- forward - forward
## PersimmonForTokenClassification
[[autodoc]] PersimmonForTokenClassification
- forward

View File

@@ -80,3 +80,8 @@ In the following, we demonstrate how to use `Qwen2-7B-Chat-beta` for the inferen
[[autodoc]] Qwen2ForSequenceClassification [[autodoc]] Qwen2ForSequenceClassification
- forward - forward
## Qwen2ForTokenClassification
[[autodoc]] Qwen2ForTokenClassification
- forward

View File

@@ -75,3 +75,8 @@ In the following, we demonstrate how to use `Qwen1.5-MoE-A2.7B-Chat` for the inf
[[autodoc]] Qwen2MoeForSequenceClassification [[autodoc]] Qwen2MoeForSequenceClassification
- forward - forward
## Qwen2MoeForTokenClassification
[[autodoc]] Qwen2MoeForTokenClassification
- forward

View File

@@ -104,3 +104,8 @@ Now, to run the model with Flash Attention 2, refer to the snippet below:
[[autodoc]] StableLmForSequenceClassification [[autodoc]] StableLmForSequenceClassification
- forward - forward
## StableLmForTokenClassification
[[autodoc]] StableLmForTokenClassification
- forward

View File

@@ -66,3 +66,8 @@ These ready-to-use checkpoints can be downloaded and used via the HuggingFace Hu
[[autodoc]] Starcoder2ForSequenceClassification [[autodoc]] Starcoder2ForSequenceClassification
- forward - forward
## Starcoder2ForTokenClassification
[[autodoc]] Starcoder2ForTokenClassification
- forward

View File

@@ -2031,6 +2031,7 @@ else:
[ [
"GemmaForCausalLM", "GemmaForCausalLM",
"GemmaForSequenceClassification", "GemmaForSequenceClassification",
"GemmaForTokenClassification",
"GemmaModel", "GemmaModel",
"GemmaPreTrainedModel", "GemmaPreTrainedModel",
] ]
@@ -2288,6 +2289,7 @@ else:
"LlamaForCausalLM", "LlamaForCausalLM",
"LlamaForQuestionAnswering", "LlamaForQuestionAnswering",
"LlamaForSequenceClassification", "LlamaForSequenceClassification",
"LlamaForTokenClassification",
"LlamaModel", "LlamaModel",
"LlamaPreTrainedModel", "LlamaPreTrainedModel",
] ]
@@ -2435,12 +2437,19 @@ else:
[ [
"MistralForCausalLM", "MistralForCausalLM",
"MistralForSequenceClassification", "MistralForSequenceClassification",
"MistralForTokenClassification",
"MistralModel", "MistralModel",
"MistralPreTrainedModel", "MistralPreTrainedModel",
] ]
) )
_import_structure["models.mixtral"].extend( _import_structure["models.mixtral"].extend(
["MixtralForCausalLM", "MixtralForSequenceClassification", "MixtralModel", "MixtralPreTrainedModel"] [
"MixtralForCausalLM",
"MixtralForSequenceClassification",
"MixtralForTokenClassification",
"MixtralModel",
"MixtralPreTrainedModel",
]
) )
_import_structure["models.mobilebert"].extend( _import_structure["models.mobilebert"].extend(
[ [
@@ -2714,6 +2723,7 @@ else:
[ [
"PersimmonForCausalLM", "PersimmonForCausalLM",
"PersimmonForSequenceClassification", "PersimmonForSequenceClassification",
"PersimmonForTokenClassification",
"PersimmonModel", "PersimmonModel",
"PersimmonPreTrainedModel", "PersimmonPreTrainedModel",
] ]
@@ -2810,6 +2820,7 @@ else:
[ [
"Qwen2ForCausalLM", "Qwen2ForCausalLM",
"Qwen2ForSequenceClassification", "Qwen2ForSequenceClassification",
"Qwen2ForTokenClassification",
"Qwen2Model", "Qwen2Model",
"Qwen2PreTrainedModel", "Qwen2PreTrainedModel",
] ]
@@ -2818,6 +2829,7 @@ else:
[ [
"Qwen2MoeForCausalLM", "Qwen2MoeForCausalLM",
"Qwen2MoeForSequenceClassification", "Qwen2MoeForSequenceClassification",
"Qwen2MoeForTokenClassification",
"Qwen2MoeModel", "Qwen2MoeModel",
"Qwen2MoePreTrainedModel", "Qwen2MoePreTrainedModel",
] ]
@@ -3066,6 +3078,7 @@ else:
[ [
"StableLmForCausalLM", "StableLmForCausalLM",
"StableLmForSequenceClassification", "StableLmForSequenceClassification",
"StableLmForTokenClassification",
"StableLmModel", "StableLmModel",
"StableLmPreTrainedModel", "StableLmPreTrainedModel",
] ]
@@ -3074,6 +3087,7 @@ else:
[ [
"Starcoder2ForCausalLM", "Starcoder2ForCausalLM",
"Starcoder2ForSequenceClassification", "Starcoder2ForSequenceClassification",
"Starcoder2ForTokenClassification",
"Starcoder2Model", "Starcoder2Model",
"Starcoder2PreTrainedModel", "Starcoder2PreTrainedModel",
] ]
@@ -6489,6 +6503,7 @@ if TYPE_CHECKING:
from .models.gemma import ( from .models.gemma import (
GemmaForCausalLM, GemmaForCausalLM,
GemmaForSequenceClassification, GemmaForSequenceClassification,
GemmaForTokenClassification,
GemmaModel, GemmaModel,
GemmaPreTrainedModel, GemmaPreTrainedModel,
) )
@@ -6686,6 +6701,7 @@ if TYPE_CHECKING:
LlamaForCausalLM, LlamaForCausalLM,
LlamaForQuestionAnswering, LlamaForQuestionAnswering,
LlamaForSequenceClassification, LlamaForSequenceClassification,
LlamaForTokenClassification,
LlamaModel, LlamaModel,
LlamaPreTrainedModel, LlamaPreTrainedModel,
) )
@@ -6801,12 +6817,14 @@ if TYPE_CHECKING:
from .models.mistral import ( from .models.mistral import (
MistralForCausalLM, MistralForCausalLM,
MistralForSequenceClassification, MistralForSequenceClassification,
MistralForTokenClassification,
MistralModel, MistralModel,
MistralPreTrainedModel, MistralPreTrainedModel,
) )
from .models.mixtral import ( from .models.mixtral import (
MixtralForCausalLM, MixtralForCausalLM,
MixtralForSequenceClassification, MixtralForSequenceClassification,
MixtralForTokenClassification,
MixtralModel, MixtralModel,
MixtralPreTrainedModel, MixtralPreTrainedModel,
) )
@@ -7025,6 +7043,7 @@ if TYPE_CHECKING:
from .models.persimmon import ( from .models.persimmon import (
PersimmonForCausalLM, PersimmonForCausalLM,
PersimmonForSequenceClassification, PersimmonForSequenceClassification,
PersimmonForTokenClassification,
PersimmonModel, PersimmonModel,
PersimmonPreTrainedModel, PersimmonPreTrainedModel,
) )
@@ -7099,12 +7118,14 @@ if TYPE_CHECKING:
from .models.qwen2 import ( from .models.qwen2 import (
Qwen2ForCausalLM, Qwen2ForCausalLM,
Qwen2ForSequenceClassification, Qwen2ForSequenceClassification,
Qwen2ForTokenClassification,
Qwen2Model, Qwen2Model,
Qwen2PreTrainedModel, Qwen2PreTrainedModel,
) )
from .models.qwen2_moe import ( from .models.qwen2_moe import (
Qwen2MoeForCausalLM, Qwen2MoeForCausalLM,
Qwen2MoeForSequenceClassification, Qwen2MoeForSequenceClassification,
Qwen2MoeForTokenClassification,
Qwen2MoeModel, Qwen2MoeModel,
Qwen2MoePreTrainedModel, Qwen2MoePreTrainedModel,
) )
@@ -7306,12 +7327,14 @@ if TYPE_CHECKING:
from .models.stablelm import ( from .models.stablelm import (
StableLmForCausalLM, StableLmForCausalLM,
StableLmForSequenceClassification, StableLmForSequenceClassification,
StableLmForTokenClassification,
StableLmModel, StableLmModel,
StableLmPreTrainedModel, StableLmPreTrainedModel,
) )
from .models.starcoder2 import ( from .models.starcoder2 import (
Starcoder2ForCausalLM, Starcoder2ForCausalLM,
Starcoder2ForSequenceClassification, Starcoder2ForSequenceClassification,
Starcoder2ForTokenClassification,
Starcoder2Model, Starcoder2Model,
Starcoder2PreTrainedModel, Starcoder2PreTrainedModel,
) )

View File

@@ -1038,6 +1038,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("flaubert", "FlaubertForTokenClassification"), ("flaubert", "FlaubertForTokenClassification"),
("fnet", "FNetForTokenClassification"), ("fnet", "FNetForTokenClassification"),
("funnel", "FunnelForTokenClassification"), ("funnel", "FunnelForTokenClassification"),
("gemma", "GemmaForTokenClassification"),
("gpt-sw3", "GPT2ForTokenClassification"), ("gpt-sw3", "GPT2ForTokenClassification"),
("gpt2", "GPT2ForTokenClassification"), ("gpt2", "GPT2ForTokenClassification"),
("gpt_bigcode", "GPTBigCodeForTokenClassification"), ("gpt_bigcode", "GPTBigCodeForTokenClassification"),
@@ -1048,11 +1049,14 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("layoutlmv2", "LayoutLMv2ForTokenClassification"), ("layoutlmv2", "LayoutLMv2ForTokenClassification"),
("layoutlmv3", "LayoutLMv3ForTokenClassification"), ("layoutlmv3", "LayoutLMv3ForTokenClassification"),
("lilt", "LiltForTokenClassification"), ("lilt", "LiltForTokenClassification"),
("llama", "LlamaForTokenClassification"),
("longformer", "LongformerForTokenClassification"), ("longformer", "LongformerForTokenClassification"),
("luke", "LukeForTokenClassification"), ("luke", "LukeForTokenClassification"),
("markuplm", "MarkupLMForTokenClassification"), ("markuplm", "MarkupLMForTokenClassification"),
("mega", "MegaForTokenClassification"), ("mega", "MegaForTokenClassification"),
("megatron-bert", "MegatronBertForTokenClassification"), ("megatron-bert", "MegatronBertForTokenClassification"),
("mistral", "MistralForTokenClassification"),
("mixtral", "MixtralForTokenClassification"),
("mobilebert", "MobileBertForTokenClassification"), ("mobilebert", "MobileBertForTokenClassification"),
("mpnet", "MPNetForTokenClassification"), ("mpnet", "MPNetForTokenClassification"),
("mpt", "MptForTokenClassification"), ("mpt", "MptForTokenClassification"),
@@ -1060,15 +1064,20 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("mt5", "MT5ForTokenClassification"), ("mt5", "MT5ForTokenClassification"),
("nezha", "NezhaForTokenClassification"), ("nezha", "NezhaForTokenClassification"),
("nystromformer", "NystromformerForTokenClassification"), ("nystromformer", "NystromformerForTokenClassification"),
("persimmon", "PersimmonForTokenClassification"),
("phi", "PhiForTokenClassification"), ("phi", "PhiForTokenClassification"),
("phi3", "Phi3ForTokenClassification"), ("phi3", "Phi3ForTokenClassification"),
("qdqbert", "QDQBertForTokenClassification"), ("qdqbert", "QDQBertForTokenClassification"),
("qwen2", "Qwen2ForTokenClassification"),
("qwen2_moe", "Qwen2MoeForTokenClassification"),
("rembert", "RemBertForTokenClassification"), ("rembert", "RemBertForTokenClassification"),
("roberta", "RobertaForTokenClassification"), ("roberta", "RobertaForTokenClassification"),
("roberta-prelayernorm", "RobertaPreLayerNormForTokenClassification"), ("roberta-prelayernorm", "RobertaPreLayerNormForTokenClassification"),
("roc_bert", "RoCBertForTokenClassification"), ("roc_bert", "RoCBertForTokenClassification"),
("roformer", "RoFormerForTokenClassification"), ("roformer", "RoFormerForTokenClassification"),
("squeezebert", "SqueezeBertForTokenClassification"), ("squeezebert", "SqueezeBertForTokenClassification"),
("stablelm", "StableLmForTokenClassification"),
("starcoder2", "Starcoder2ForTokenClassification"),
("t5", "T5ForTokenClassification"), ("t5", "T5ForTokenClassification"),
("umt5", "UMT5ForTokenClassification"), ("umt5", "UMT5ForTokenClassification"),
("xlm", "XLMForTokenClassification"), ("xlm", "XLMForTokenClassification"),

View File

@@ -55,6 +55,7 @@ else:
"GemmaModel", "GemmaModel",
"GemmaPreTrainedModel", "GemmaPreTrainedModel",
"GemmaForSequenceClassification", "GemmaForSequenceClassification",
"GemmaForTokenClassification",
] ]
try: try:
@@ -98,6 +99,7 @@ if TYPE_CHECKING:
from .modeling_gemma import ( from .modeling_gemma import (
GemmaForCausalLM, GemmaForCausalLM,
GemmaForSequenceClassification, GemmaForSequenceClassification,
GemmaForTokenClassification,
GemmaModel, GemmaModel,
GemmaPreTrainedModel, GemmaPreTrainedModel,
) )

View File

@@ -30,7 +30,12 @@ from ...modeling_attn_mask_utils import (
AttentionMaskConverter, AttentionMaskConverter,
_prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask,
) )
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13 from ...pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
from ...utils import ( from ...utils import (
@@ -1346,3 +1351,88 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
hidden_states=transformer_outputs.hidden_states, hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions, attentions=transformer_outputs.attentions,
) )
@add_start_docstrings(
"""
The Gemma Model transformer with a token classification head on top (a linear layer on top of the hidden-states
output) e.g. for Named-Entity-Recognition (NER) tasks.
""",
GEMMA_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Gemma, LLAMA->GEMMA
class GemmaForTokenClassification(GemmaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = GemmaModel(config)
if getattr(config, "classifier_dropout", None) is not None:
classifier_dropout = config.classifier_dropout
elif getattr(config, "hidden_dropout", None) is not None:
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
self.dropout = nn.Dropout(classifier_dropout)
self.score = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.score(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

View File

@@ -55,6 +55,7 @@ else:
"LlamaPreTrainedModel", "LlamaPreTrainedModel",
"LlamaForSequenceClassification", "LlamaForSequenceClassification",
"LlamaForQuestionAnswering", "LlamaForQuestionAnswering",
"LlamaForTokenClassification",
] ]
try: try:
@@ -95,6 +96,7 @@ if TYPE_CHECKING:
LlamaForCausalLM, LlamaForCausalLM,
LlamaForQuestionAnswering, LlamaForQuestionAnswering,
LlamaForSequenceClassification, LlamaForSequenceClassification,
LlamaForTokenClassification,
LlamaModel, LlamaModel,
LlamaPreTrainedModel, LlamaPreTrainedModel,
) )

View File

@@ -36,6 +36,7 @@ from ...modeling_outputs import (
CausalLMOutputWithPast, CausalLMOutputWithPast,
QuestionAnsweringModelOutput, QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast, SequenceClassifierOutputWithPast,
TokenClassifierOutput,
) )
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...pytorch_utils import ALL_LAYERNORM_LAYERS
@@ -1516,3 +1517,87 @@ class LlamaForQuestionAnswering(LlamaPreTrainedModel):
hidden_states=outputs.hidden_states, hidden_states=outputs.hidden_states,
attentions=outputs.attentions, attentions=outputs.attentions,
) )
@add_start_docstrings(
"""
The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states
output) e.g. for Named-Entity-Recognition (NER) tasks.
""",
LLAMA_START_DOCSTRING,
)
class LlamaForTokenClassification(LlamaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = LlamaModel(config)
if getattr(config, "classifier_dropout", None) is not None:
classifier_dropout = config.classifier_dropout
elif getattr(config, "hidden_dropout", None) is not None:
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
self.dropout = nn.Dropout(classifier_dropout)
self.score = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.score(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

View File

@@ -32,6 +32,7 @@ else:
"MistralModel", "MistralModel",
"MistralPreTrainedModel", "MistralPreTrainedModel",
"MistralForSequenceClassification", "MistralForSequenceClassification",
"MistralForTokenClassification",
] ]
try: try:
@@ -59,6 +60,7 @@ if TYPE_CHECKING:
from .modeling_mistral import ( from .modeling_mistral import (
MistralForCausalLM, MistralForCausalLM,
MistralForSequenceClassification, MistralForSequenceClassification,
MistralForTokenClassification,
MistralModel, MistralModel,
MistralPreTrainedModel, MistralPreTrainedModel,
) )

View File

@@ -31,7 +31,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache from ...cache_utils import Cache, DynamicCache
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import ( from ...utils import (
add_start_docstrings, add_start_docstrings,
@@ -1366,3 +1371,88 @@ class MistralForSequenceClassification(MistralPreTrainedModel):
hidden_states=transformer_outputs.hidden_states, hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions, attentions=transformer_outputs.attentions,
) )
@add_start_docstrings(
"""
The Mistral Model transformer with a token classification head on top (a linear layer on top of the hidden-states
output) e.g. for Named-Entity-Recognition (NER) tasks.
""",
MISTRAL_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mistral, LLAMA->MISTRAL
class MistralForTokenClassification(MistralPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = MistralModel(config)
if getattr(config, "classifier_dropout", None) is not None:
classifier_dropout = config.classifier_dropout
elif getattr(config, "hidden_dropout", None) is not None:
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
self.dropout = nn.Dropout(classifier_dropout)
self.score = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.score(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

View File

@@ -36,6 +36,7 @@ else:
"MixtralModel", "MixtralModel",
"MixtralPreTrainedModel", "MixtralPreTrainedModel",
"MixtralForSequenceClassification", "MixtralForSequenceClassification",
"MixtralForTokenClassification",
] ]
@@ -51,6 +52,7 @@ if TYPE_CHECKING:
from .modeling_mixtral import ( from .modeling_mixtral import (
MixtralForCausalLM, MixtralForCausalLM,
MixtralForSequenceClassification, MixtralForSequenceClassification,
MixtralForTokenClassification,
MixtralModel, MixtralModel,
MixtralPreTrainedModel, MixtralPreTrainedModel,
) )

View File

@@ -38,6 +38,7 @@ from ...modeling_outputs import (
MoeCausalLMOutputWithPast, MoeCausalLMOutputWithPast,
MoeModelOutputWithPast, MoeModelOutputWithPast,
SequenceClassifierOutputWithPast, SequenceClassifierOutputWithPast,
TokenClassifierOutput,
) )
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import is_torch_greater_or_equal_than_1_13 from ...pytorch_utils import is_torch_greater_or_equal_than_1_13
@@ -1582,3 +1583,88 @@ class MixtralForSequenceClassification(MixtralPreTrainedModel):
hidden_states=transformer_outputs.hidden_states, hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions, attentions=transformer_outputs.attentions,
) )
@add_start_docstrings(
"""
The Mixtral Model transformer with a token classification head on top (a linear layer on top of the hidden-states
output) e.g. for Named-Entity-Recognition (NER) tasks.
""",
MIXTRAL_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mixtral, LLAMA->MIXTRAL
class MixtralForTokenClassification(MixtralPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = MixtralModel(config)
if getattr(config, "classifier_dropout", None) is not None:
classifier_dropout = config.classifier_dropout
elif getattr(config, "hidden_dropout", None) is not None:
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
self.dropout = nn.Dropout(classifier_dropout)
self.score = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.score(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

View File

@@ -36,6 +36,7 @@ else:
"PersimmonModel", "PersimmonModel",
"PersimmonPreTrainedModel", "PersimmonPreTrainedModel",
"PersimmonForSequenceClassification", "PersimmonForSequenceClassification",
"PersimmonForTokenClassification",
] ]
@@ -51,6 +52,7 @@ if TYPE_CHECKING:
from .modeling_persimmon import ( from .modeling_persimmon import (
PersimmonForCausalLM, PersimmonForCausalLM,
PersimmonForSequenceClassification, PersimmonForSequenceClassification,
PersimmonForTokenClassification,
PersimmonModel, PersimmonModel,
PersimmonPreTrainedModel, PersimmonPreTrainedModel,
) )

View File

@@ -29,7 +29,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache from ...cache_utils import Cache, DynamicCache
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_persimmon import PersimmonConfig from .configuration_persimmon import PersimmonConfig
@@ -1011,3 +1016,88 @@ class PersimmonForSequenceClassification(PersimmonPreTrainedModel):
hidden_states=transformer_outputs.hidden_states, hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions, attentions=transformer_outputs.attentions,
) )
@add_start_docstrings(
"""
The Persimmon Model transformer with a token classification head on top (a linear layer on top of the hidden-states
output) e.g. for Named-Entity-Recognition (NER) tasks.
""",
PERSIMMON_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Persimmon, LLAMA->PERSIMMON
class PersimmonForTokenClassification(PersimmonPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = PersimmonModel(config)
if getattr(config, "classifier_dropout", None) is not None:
classifier_dropout = config.classifier_dropout
elif getattr(config, "hidden_dropout", None) is not None:
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
self.dropout = nn.Dropout(classifier_dropout)
self.score = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(PERSIMMON_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.score(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

View File

@@ -45,6 +45,7 @@ else:
"Qwen2Model", "Qwen2Model",
"Qwen2PreTrainedModel", "Qwen2PreTrainedModel",
"Qwen2ForSequenceClassification", "Qwen2ForSequenceClassification",
"Qwen2ForTokenClassification",
] ]
@@ -69,6 +70,7 @@ if TYPE_CHECKING:
from .modeling_qwen2 import ( from .modeling_qwen2 import (
Qwen2ForCausalLM, Qwen2ForCausalLM,
Qwen2ForSequenceClassification, Qwen2ForSequenceClassification,
Qwen2ForTokenClassification,
Qwen2Model, Qwen2Model,
Qwen2PreTrainedModel, Qwen2PreTrainedModel,
) )

View File

@@ -31,7 +31,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache from ...cache_utils import Cache, DynamicCache
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import ( from ...utils import (
add_start_docstrings, add_start_docstrings,
@@ -1375,3 +1380,88 @@ class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
hidden_states=transformer_outputs.hidden_states, hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions, attentions=transformer_outputs.attentions,
) )
@add_start_docstrings(
"""
The Qwen2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states
output) e.g. for Named-Entity-Recognition (NER) tasks.
""",
QWEN2_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2, LLAMA->QWEN2
class Qwen2ForTokenClassification(Qwen2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = Qwen2Model(config)
if getattr(config, "classifier_dropout", None) is not None:
classifier_dropout = config.classifier_dropout
elif getattr(config, "hidden_dropout", None) is not None:
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
self.dropout = nn.Dropout(classifier_dropout)
self.score = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.score(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

View File

@@ -36,6 +36,7 @@ else:
"Qwen2MoeModel", "Qwen2MoeModel",
"Qwen2MoePreTrainedModel", "Qwen2MoePreTrainedModel",
"Qwen2MoeForSequenceClassification", "Qwen2MoeForSequenceClassification",
"Qwen2MoeForTokenClassification",
] ]
@@ -51,6 +52,7 @@ if TYPE_CHECKING:
from .modeling_qwen2_moe import ( from .modeling_qwen2_moe import (
Qwen2MoeForCausalLM, Qwen2MoeForCausalLM,
Qwen2MoeForSequenceClassification, Qwen2MoeForSequenceClassification,
Qwen2MoeForTokenClassification,
Qwen2MoeModel, Qwen2MoeModel,
Qwen2MoePreTrainedModel, Qwen2MoePreTrainedModel,
) )

View File

@@ -32,7 +32,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache from ...cache_utils import Cache, DynamicCache
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_outputs import (
MoeCausalLMOutputWithPast,
MoeModelOutputWithPast,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import ( from ...utils import (
add_start_docstrings, add_start_docstrings,
@@ -1571,3 +1576,88 @@ class Qwen2MoeForSequenceClassification(Qwen2MoePreTrainedModel):
hidden_states=transformer_outputs.hidden_states, hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions, attentions=transformer_outputs.attentions,
) )
@add_start_docstrings(
"""
The Qwen2MoE Model transformer with a token classification head on top (a linear layer on top of the hidden-states
output) e.g. for Named-Entity-Recognition (NER) tasks.
""",
QWEN2MOE_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2Moe, LLAMA->QWEN2MOE
class Qwen2MoeForTokenClassification(Qwen2MoePreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = Qwen2MoeModel(config)
if getattr(config, "classifier_dropout", None) is not None:
classifier_dropout = config.classifier_dropout
elif getattr(config, "hidden_dropout", None) is not None:
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
self.dropout = nn.Dropout(classifier_dropout)
self.score = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.score(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

View File

@@ -36,6 +36,7 @@ else:
"StableLmModel", "StableLmModel",
"StableLmPreTrainedModel", "StableLmPreTrainedModel",
"StableLmForSequenceClassification", "StableLmForSequenceClassification",
"StableLmForTokenClassification",
] ]
@@ -51,6 +52,7 @@ if TYPE_CHECKING:
from .modeling_stablelm import ( from .modeling_stablelm import (
StableLmForCausalLM, StableLmForCausalLM,
StableLmForSequenceClassification, StableLmForSequenceClassification,
StableLmForTokenClassification,
StableLmModel, StableLmModel,
StableLmPreTrainedModel, StableLmPreTrainedModel,
) )

View File

@@ -30,7 +30,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache from ...cache_utils import Cache, DynamicCache
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import ( from ...utils import (
add_start_docstrings, add_start_docstrings,
@@ -1383,3 +1388,88 @@ class StableLmForSequenceClassification(StableLmPreTrainedModel):
hidden_states=transformer_outputs.hidden_states, hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions, attentions=transformer_outputs.attentions,
) )
@add_start_docstrings(
"""
The StableLm Model transformer with a token classification head on top (a linear layer on top of the hidden-states
output) e.g. for Named-Entity-Recognition (NER) tasks.
""",
STABLELM_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->StableLm, LLAMA->STABLELM
class StableLmForTokenClassification(StableLmPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = StableLmModel(config)
if getattr(config, "classifier_dropout", None) is not None:
classifier_dropout = config.classifier_dropout
elif getattr(config, "hidden_dropout", None) is not None:
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
self.dropout = nn.Dropout(classifier_dropout)
self.score = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(STABLELM_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.score(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

View File

@@ -36,6 +36,7 @@ else:
"Starcoder2Model", "Starcoder2Model",
"Starcoder2PreTrainedModel", "Starcoder2PreTrainedModel",
"Starcoder2ForSequenceClassification", "Starcoder2ForSequenceClassification",
"Starcoder2ForTokenClassification",
] ]
@@ -51,6 +52,7 @@ if TYPE_CHECKING:
from .modeling_starcoder2 import ( from .modeling_starcoder2 import (
Starcoder2ForCausalLM, Starcoder2ForCausalLM,
Starcoder2ForSequenceClassification, Starcoder2ForSequenceClassification,
Starcoder2ForTokenClassification,
Starcoder2Model, Starcoder2Model,
Starcoder2PreTrainedModel, Starcoder2PreTrainedModel,
) )

View File

@@ -31,7 +31,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache from ...cache_utils import Cache, DynamicCache
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import ( from ...utils import (
add_start_docstrings, add_start_docstrings,
@@ -1357,3 +1362,88 @@ class Starcoder2ForSequenceClassification(Starcoder2PreTrainedModel):
hidden_states=transformer_outputs.hidden_states, hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions, attentions=transformer_outputs.attentions,
) )
@add_start_docstrings(
"""
The Starcoder2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states
output) e.g. for Named-Entity-Recognition (NER) tasks.
""",
STARCODER2_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Starcoder2, LLAMA->STARCODER2
class Starcoder2ForTokenClassification(Starcoder2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = Starcoder2Model(config)
if getattr(config, "classifier_dropout", None) is not None:
classifier_dropout = config.classifier_dropout
elif getattr(config, "hidden_dropout", None) is not None:
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
self.dropout = nn.Dropout(classifier_dropout)
self.score = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.score(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

View File

@@ -3692,6 +3692,13 @@ class GemmaForSequenceClassification(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class GemmaForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class GemmaModel(metaclass=DummyObject): class GemmaModel(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
@@ -4642,6 +4649,13 @@ class LlamaForSequenceClassification(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class LlamaForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class LlamaModel(metaclass=DummyObject): class LlamaModel(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
@@ -5237,6 +5251,13 @@ class MistralForSequenceClassification(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class MistralForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MistralModel(metaclass=DummyObject): class MistralModel(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
@@ -5265,6 +5286,13 @@ class MixtralForSequenceClassification(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class MixtralForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MixtralModel(metaclass=DummyObject): class MixtralModel(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
@@ -6373,6 +6401,13 @@ class PersimmonForSequenceClassification(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class PersimmonForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class PersimmonModel(metaclass=DummyObject): class PersimmonModel(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
@@ -6734,6 +6769,13 @@ class Qwen2ForSequenceClassification(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class Qwen2ForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Qwen2Model(metaclass=DummyObject): class Qwen2Model(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
@@ -6762,6 +6804,13 @@ class Qwen2MoeForSequenceClassification(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class Qwen2MoeForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Qwen2MoeModel(metaclass=DummyObject): class Qwen2MoeModel(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
@@ -7793,6 +7842,13 @@ class StableLmForSequenceClassification(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class StableLmForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class StableLmModel(metaclass=DummyObject): class StableLmModel(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
@@ -7821,6 +7877,13 @@ class Starcoder2ForSequenceClassification(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class Starcoder2ForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Starcoder2Model(metaclass=DummyObject): class Starcoder2Model(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]

View File

@@ -41,7 +41,13 @@ from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
import torch import torch
from transformers import GemmaForCausalLM, GemmaForSequenceClassification, GemmaModel, GemmaTokenizer from transformers import (
GemmaForCausalLM,
GemmaForSequenceClassification,
GemmaForTokenClassification,
GemmaModel,
GemmaTokenizer,
)
class GemmaModelTester: class GemmaModelTester:
@@ -284,12 +290,17 @@ class GemmaModelTester:
@require_torch @require_torch
class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (GemmaModel, GemmaForCausalLM, GemmaForSequenceClassification) if is_torch_available() else () all_model_classes = (
(GemmaModel, GemmaForCausalLM, GemmaForSequenceClassification, GemmaForTokenClassification)
if is_torch_available()
else ()
)
all_generative_model_classes = (GemmaForCausalLM,) if is_torch_available() else () all_generative_model_classes = (GemmaForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = ( pipeline_model_mapping = (
{ {
"feature-extraction": GemmaModel, "feature-extraction": GemmaModel,
"text-classification": GemmaForSequenceClassification, "text-classification": GemmaForSequenceClassification,
"token-classification": GemmaForTokenClassification,
"text-generation": GemmaForCausalLM, "text-generation": GemmaForCausalLM,
"zero-shot": GemmaForSequenceClassification, "zero-shot": GemmaForSequenceClassification,
} }
@@ -370,6 +381,22 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Gemma,llama->Gemma
def test_Gemma_token_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
model = GemmaForTokenClassification(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
self.assertEqual(
result.logits.shape,
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
)
@unittest.skip("Gemma buffers include complex numbers, which breaks this test") @unittest.skip("Gemma buffers include complex numbers, which breaks this test")
def test_save_load_fast_init_from_base(self): def test_save_load_fast_init_from_base(self):
pass pass

View File

@@ -47,6 +47,7 @@ if is_torch_available():
LlamaForCausalLM, LlamaForCausalLM,
LlamaForQuestionAnswering, LlamaForQuestionAnswering,
LlamaForSequenceClassification, LlamaForSequenceClassification,
LlamaForTokenClassification,
LlamaModel, LlamaModel,
LlamaTokenizer, LlamaTokenizer,
) )
@@ -286,7 +287,13 @@ class LlamaModelTester:
@require_torch @require_torch
class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
(LlamaModel, LlamaForCausalLM, LlamaForSequenceClassification, LlamaForQuestionAnswering) (
LlamaModel,
LlamaForCausalLM,
LlamaForSequenceClassification,
LlamaForQuestionAnswering,
LlamaForTokenClassification,
)
if is_torch_available() if is_torch_available()
else () else ()
) )
@@ -298,6 +305,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
"text-generation": LlamaForCausalLM, "text-generation": LlamaForCausalLM,
"zero-shot": LlamaForSequenceClassification, "zero-shot": LlamaForSequenceClassification,
"question-answering": LlamaForQuestionAnswering, "question-answering": LlamaForQuestionAnswering,
"token-classification": LlamaForTokenClassification,
} }
if is_torch_available() if is_torch_available()
else {} else {}
@@ -370,6 +378,21 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
def test_llama_token_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
model = LlamaForTokenClassification(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
self.assertEqual(
result.logits.shape,
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
)
@unittest.skip("Llama buffers include complex numbers, which breaks this test") @unittest.skip("Llama buffers include complex numbers, which breaks this test")
def test_save_load_fast_init_from_base(self): def test_save_load_fast_init_from_base(self):
pass pass

View File

@@ -46,6 +46,7 @@ if is_torch_available():
from transformers import ( from transformers import (
MistralForCausalLM, MistralForCausalLM,
MistralForSequenceClassification, MistralForSequenceClassification,
MistralForTokenClassification,
MistralModel, MistralModel,
) )
@@ -288,13 +289,16 @@ class MistralModelTester:
@require_torch @require_torch
class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
(MistralModel, MistralForCausalLM, MistralForSequenceClassification) if is_torch_available() else () (MistralModel, MistralForCausalLM, MistralForSequenceClassification, MistralForTokenClassification)
if is_torch_available()
else ()
) )
all_generative_model_classes = (MistralForCausalLM,) if is_torch_available() else () all_generative_model_classes = (MistralForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = ( pipeline_model_mapping = (
{ {
"feature-extraction": MistralModel, "feature-extraction": MistralModel,
"text-classification": MistralForSequenceClassification, "text-classification": MistralForSequenceClassification,
"token-classification": MistralForTokenClassification,
"text-generation": MistralForCausalLM, "text-generation": MistralForCausalLM,
"zero-shot": MistralForSequenceClassification, "zero-shot": MistralForSequenceClassification,
} }
@@ -376,6 +380,22 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Mistral,llama->Mistral
def test_Mistral_token_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
model = MistralForTokenClassification(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
self.assertEqual(
result.logits.shape,
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
)
@unittest.skip("Mistral buffers include complex numbers, which breaks this test") @unittest.skip("Mistral buffers include complex numbers, which breaks this test")
def test_save_load_fast_init_from_base(self): def test_save_load_fast_init_from_base(self):
pass pass

View File

@@ -40,7 +40,12 @@ from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
import torch import torch
from transformers import MixtralForCausalLM, MixtralForSequenceClassification, MixtralModel from transformers import (
MixtralForCausalLM,
MixtralForSequenceClassification,
MixtralForTokenClassification,
MixtralModel,
)
class MixtralModelTester: class MixtralModelTester:
@@ -287,13 +292,16 @@ class MixtralModelTester:
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Mixtral # Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Mixtral
class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
(MixtralModel, MixtralForCausalLM, MixtralForSequenceClassification) if is_torch_available() else () (MixtralModel, MixtralForCausalLM, MixtralForSequenceClassification, MixtralForTokenClassification)
if is_torch_available()
else ()
) )
all_generative_model_classes = (MixtralForCausalLM,) if is_torch_available() else () all_generative_model_classes = (MixtralForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = ( pipeline_model_mapping = (
{ {
"feature-extraction": MixtralModel, "feature-extraction": MixtralModel,
"text-classification": MixtralForSequenceClassification, "text-classification": MixtralForSequenceClassification,
"token-classification": MixtralForTokenClassification,
"text-generation": MixtralForCausalLM, "text-generation": MixtralForCausalLM,
"zero-shot": MixtralForSequenceClassification, "zero-shot": MixtralForSequenceClassification,
} }
@@ -375,6 +383,22 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Mixtral,llama->Mixtral
def test_Mixtral_token_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
model = MixtralForTokenClassification(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
self.assertEqual(
result.logits.shape,
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
)
@unittest.skip("Mixtral buffers include complex numbers, which breaks this test") @unittest.skip("Mixtral buffers include complex numbers, which breaks this test")
def test_save_load_fast_init_from_base(self): def test_save_load_fast_init_from_base(self):
pass pass

View File

@@ -44,6 +44,7 @@ if is_torch_available():
AutoTokenizer, AutoTokenizer,
PersimmonForCausalLM, PersimmonForCausalLM,
PersimmonForSequenceClassification, PersimmonForSequenceClassification,
PersimmonForTokenClassification,
PersimmonModel, PersimmonModel,
) )
from transformers.models.persimmon.modeling_persimmon import ( from transformers.models.persimmon.modeling_persimmon import (
@@ -283,12 +284,15 @@ class PersimmonModelTester:
@require_torch @require_torch
class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
(PersimmonModel, PersimmonForCausalLM, PersimmonForSequenceClassification) if is_torch_available() else () (PersimmonModel, PersimmonForCausalLM, PersimmonForSequenceClassification, PersimmonForTokenClassification)
if is_torch_available()
else ()
) )
pipeline_model_mapping = ( pipeline_model_mapping = (
{ {
"feature-extraction": PersimmonModel, "feature-extraction": PersimmonModel,
"text-classification": PersimmonForSequenceClassification, "text-classification": PersimmonForSequenceClassification,
"token-classification": PersimmonForTokenClassification,
# TODO (ydshieh): check why these two fail. Fix them or skip them in a better way. # TODO (ydshieh): check why these two fail. Fix them or skip them in a better way.
# "text-generation": PersimmonForCausalLM, # "text-generation": PersimmonForCausalLM,
# "zero-shot": PersimmonForSequenceClassification, # "zero-shot": PersimmonForSequenceClassification,
@@ -365,6 +369,22 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Persimmon,llama->persimmon
def test_persimmon_token_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
model = PersimmonForTokenClassification(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
self.assertEqual(
result.logits.shape,
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
)
@unittest.skip("Persimmon buffers include complex numbers, which breaks this test") @unittest.skip("Persimmon buffers include complex numbers, which breaks this test")
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_save_load_fast_init_from_base # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_save_load_fast_init_from_base
def test_save_load_fast_init_from_base(self): def test_save_load_fast_init_from_base(self):

View File

@@ -45,6 +45,7 @@ if is_torch_available():
from transformers import ( from transformers import (
Qwen2ForCausalLM, Qwen2ForCausalLM,
Qwen2ForSequenceClassification, Qwen2ForSequenceClassification,
Qwen2ForTokenClassification,
Qwen2Model, Qwen2Model,
) )
@@ -299,12 +300,17 @@ class Qwen2ModelTester:
@require_torch @require_torch
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen2 # Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen2
class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (Qwen2Model, Qwen2ForCausalLM, Qwen2ForSequenceClassification) if is_torch_available() else () all_model_classes = (
(Qwen2Model, Qwen2ForCausalLM, Qwen2ForSequenceClassification, Qwen2ForTokenClassification)
if is_torch_available()
else ()
)
all_generative_model_classes = (Qwen2ForCausalLM,) if is_torch_available() else () all_generative_model_classes = (Qwen2ForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = ( pipeline_model_mapping = (
{ {
"feature-extraction": Qwen2Model, "feature-extraction": Qwen2Model,
"text-classification": Qwen2ForSequenceClassification, "text-classification": Qwen2ForSequenceClassification,
"token-classification": Qwen2ForTokenClassification,
"text-generation": Qwen2ForCausalLM, "text-generation": Qwen2ForCausalLM,
"zero-shot": Qwen2ForSequenceClassification, "zero-shot": Qwen2ForSequenceClassification,
} }
@@ -387,6 +393,22 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Qwen2,llama->Qwen2
def test_Qwen2_token_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
model = Qwen2ForTokenClassification(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
self.assertEqual(
result.logits.shape,
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
)
@unittest.skip("Qwen2 buffers include complex numbers, which breaks this test") @unittest.skip("Qwen2 buffers include complex numbers, which breaks this test")
def test_save_load_fast_init_from_base(self): def test_save_load_fast_init_from_base(self):
pass pass

View File

@@ -45,6 +45,7 @@ if is_torch_available():
from transformers import ( from transformers import (
Qwen2MoeForCausalLM, Qwen2MoeForCausalLM,
Qwen2MoeForSequenceClassification, Qwen2MoeForSequenceClassification,
Qwen2MoeForTokenClassification,
Qwen2MoeModel, Qwen2MoeModel,
) )
@@ -327,13 +328,16 @@ class Qwen2MoeModelTester:
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen2Moe # Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen2Moe
class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
(Qwen2MoeModel, Qwen2MoeForCausalLM, Qwen2MoeForSequenceClassification) if is_torch_available() else () (Qwen2MoeModel, Qwen2MoeForCausalLM, Qwen2MoeForSequenceClassification, Qwen2MoeForTokenClassification)
if is_torch_available()
else ()
) )
all_generative_model_classes = (Qwen2MoeForCausalLM,) if is_torch_available() else () all_generative_model_classes = (Qwen2MoeForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = ( pipeline_model_mapping = (
{ {
"feature-extraction": Qwen2MoeModel, "feature-extraction": Qwen2MoeModel,
"text-classification": Qwen2MoeForSequenceClassification, "text-classification": Qwen2MoeForSequenceClassification,
"token-classification": Qwen2MoeForTokenClassification,
"text-generation": Qwen2MoeForCausalLM, "text-generation": Qwen2MoeForCausalLM,
"zero-shot": Qwen2MoeForSequenceClassification, "zero-shot": Qwen2MoeForSequenceClassification,
} }
@@ -414,6 +418,22 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Qwen2Moe,llama->Qwen2Moe
def test_Qwen2Moe_token_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
model = Qwen2MoeForTokenClassification(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
self.assertEqual(
result.logits.shape,
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
)
@unittest.skip("Qwen2Moe buffers include complex numbers, which breaks this test") @unittest.skip("Qwen2Moe buffers include complex numbers, which breaks this test")
def test_save_load_fast_init_from_base(self): def test_save_load_fast_init_from_base(self):
pass pass

View File

@@ -43,6 +43,7 @@ if is_torch_available():
AutoTokenizer, AutoTokenizer,
StableLmForCausalLM, StableLmForCausalLM,
StableLmForSequenceClassification, StableLmForSequenceClassification,
StableLmForTokenClassification,
StableLmModel, StableLmModel,
) )
from transformers.models.stablelm.modeling_stablelm import ( from transformers.models.stablelm.modeling_stablelm import (
@@ -287,12 +288,15 @@ class StableLmModelTester:
# Copied from transformers.tests.persimmon.test_modeling_persimmon.PersimmonModelTest with Persimmon -> StableLm # Copied from transformers.tests.persimmon.test_modeling_persimmon.PersimmonModelTest with Persimmon -> StableLm
class StableLmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): class StableLmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
(StableLmModel, StableLmForCausalLM, StableLmForSequenceClassification) if is_torch_available() else () (StableLmModel, StableLmForCausalLM, StableLmForSequenceClassification, StableLmForTokenClassification)
if is_torch_available()
else ()
) )
pipeline_model_mapping = ( pipeline_model_mapping = (
{ {
"feature-extraction": StableLmModel, "feature-extraction": StableLmModel,
"text-classification": StableLmForSequenceClassification, "text-classification": StableLmForSequenceClassification,
"token-classification": StableLmForTokenClassification,
# TODO (ydshieh): check why these two fail. Fix them or skip them in a better way. # TODO (ydshieh): check why these two fail. Fix them or skip them in a better way.
# "text-generation": StableLmForCausalLM, # "text-generation": StableLmForCausalLM,
# "zero-shot": StableLmForSequenceClassification, # "zero-shot": StableLmForSequenceClassification,
@@ -356,6 +360,22 @@ class StableLmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->StableLm,llama->stablelm
def test_stablelm_token_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
model = StableLmForTokenClassification(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
self.assertEqual(
result.logits.shape,
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
)
@parameterized.expand([("linear",), ("dynamic",)]) @parameterized.expand([("linear",), ("dynamic",)])
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->StableLm # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->StableLm
def test_model_rope_scaling_from_config(self, scaling_type): def test_model_rope_scaling_from_config(self, scaling_type):

View File

@@ -43,6 +43,7 @@ if is_torch_available():
AutoTokenizer, AutoTokenizer,
Starcoder2ForCausalLM, Starcoder2ForCausalLM,
Starcoder2ForSequenceClassification, Starcoder2ForSequenceClassification,
Starcoder2ForTokenClassification,
Starcoder2Model, Starcoder2Model,
) )
@@ -290,13 +291,16 @@ class Starcoder2ModelTester:
# Copied from transformers.tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Starcoder2 # Copied from transformers.tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Starcoder2
class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
(Starcoder2Model, Starcoder2ForCausalLM, Starcoder2ForSequenceClassification) if is_torch_available() else () (Starcoder2Model, Starcoder2ForCausalLM, Starcoder2ForSequenceClassification, Starcoder2ForTokenClassification)
if is_torch_available()
else ()
) )
all_generative_model_classes = (Starcoder2ForCausalLM,) if is_torch_available() else () all_generative_model_classes = (Starcoder2ForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = ( pipeline_model_mapping = (
{ {
"feature-extraction": Starcoder2Model, "feature-extraction": Starcoder2Model,
"text-classification": Starcoder2ForSequenceClassification, "text-classification": Starcoder2ForSequenceClassification,
"token-classification": Starcoder2ForTokenClassification,
"text-generation": Starcoder2ForCausalLM, "text-generation": Starcoder2ForCausalLM,
"zero-shot": Starcoder2ForSequenceClassification, "zero-shot": Starcoder2ForSequenceClassification,
} }
@@ -370,6 +374,22 @@ class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Starcoder2,llama->Starcoder2
def test_Starcoder2_token_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
model = Starcoder2ForTokenClassification(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
self.assertEqual(
result.logits.shape,
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
)
@unittest.skip("Starcoder2 buffers include complex numbers, which breaks this test") @unittest.skip("Starcoder2 buffers include complex numbers, which breaks this test")
def test_save_load_fast_init_from_base(self): def test_save_load_fast_init_from_base(self):
pass pass