Add TokenClassification for Mistral, Mixtral and Qwen2 (#29878)
* Add MistralForTokenClassification * Add tests and docs * Add token classification for Mixtral and Qwen2 * Save Llama for token classification draft * Add token classification support for Llama, Gemma, Persimmon, StableLm and StarCoder2 * Formatting * Add token classification support for Qwen2Moe model * Add dropout layer to each ForTokenClassification model * Add copied from in tests * Update src/transformers/models/llama/modeling_llama.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Propagate suggested changes * Style --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
@@ -60,6 +60,11 @@ This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ), [
|
|||||||
[[autodoc]] GemmaForSequenceClassification
|
[[autodoc]] GemmaForSequenceClassification
|
||||||
- forward
|
- forward
|
||||||
|
|
||||||
|
## GemmaForTokenClassification
|
||||||
|
|
||||||
|
[[autodoc]] GemmaForTokenClassification
|
||||||
|
- forward
|
||||||
|
|
||||||
## FlaxGemmaModel
|
## FlaxGemmaModel
|
||||||
|
|
||||||
[[autodoc]] FlaxGemmaModel
|
[[autodoc]] FlaxGemmaModel
|
||||||
|
|||||||
@@ -121,6 +121,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
|
|||||||
[[autodoc]] LlamaForQuestionAnswering
|
[[autodoc]] LlamaForQuestionAnswering
|
||||||
- forward
|
- forward
|
||||||
|
|
||||||
|
## LlamaForTokenClassification
|
||||||
|
|
||||||
|
[[autodoc]] LlamaForTokenClassification
|
||||||
|
- forward
|
||||||
|
|
||||||
## FlaxLlamaModel
|
## FlaxLlamaModel
|
||||||
|
|
||||||
[[autodoc]] FlaxLlamaModel
|
[[autodoc]] FlaxLlamaModel
|
||||||
|
|||||||
@@ -203,6 +203,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
|
|||||||
[[autodoc]] MistralForSequenceClassification
|
[[autodoc]] MistralForSequenceClassification
|
||||||
- forward
|
- forward
|
||||||
|
|
||||||
|
## MistralForTokenClassification
|
||||||
|
|
||||||
|
[[autodoc]] MistralForTokenClassification
|
||||||
|
- forward
|
||||||
|
|
||||||
## FlaxMistralModel
|
## FlaxMistralModel
|
||||||
|
|
||||||
[[autodoc]] FlaxMistralModel
|
[[autodoc]] FlaxMistralModel
|
||||||
|
|||||||
@@ -204,3 +204,8 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
|
|||||||
|
|
||||||
[[autodoc]] MixtralForSequenceClassification
|
[[autodoc]] MixtralForSequenceClassification
|
||||||
- forward
|
- forward
|
||||||
|
|
||||||
|
## MixtralForTokenClassification
|
||||||
|
|
||||||
|
[[autodoc]] MixtralForTokenClassification
|
||||||
|
- forward
|
||||||
|
|||||||
@@ -96,3 +96,8 @@ The `LlamaTokenizer` is used as it is a standard wrapper around sentencepiece. T
|
|||||||
|
|
||||||
[[autodoc]] PersimmonForSequenceClassification
|
[[autodoc]] PersimmonForSequenceClassification
|
||||||
- forward
|
- forward
|
||||||
|
|
||||||
|
## PersimmonForTokenClassification
|
||||||
|
|
||||||
|
[[autodoc]] PersimmonForTokenClassification
|
||||||
|
- forward
|
||||||
|
|||||||
@@ -80,3 +80,8 @@ In the following, we demonstrate how to use `Qwen2-7B-Chat-beta` for the inferen
|
|||||||
|
|
||||||
[[autodoc]] Qwen2ForSequenceClassification
|
[[autodoc]] Qwen2ForSequenceClassification
|
||||||
- forward
|
- forward
|
||||||
|
|
||||||
|
## Qwen2ForTokenClassification
|
||||||
|
|
||||||
|
[[autodoc]] Qwen2ForTokenClassification
|
||||||
|
- forward
|
||||||
|
|||||||
@@ -75,3 +75,8 @@ In the following, we demonstrate how to use `Qwen1.5-MoE-A2.7B-Chat` for the inf
|
|||||||
|
|
||||||
[[autodoc]] Qwen2MoeForSequenceClassification
|
[[autodoc]] Qwen2MoeForSequenceClassification
|
||||||
- forward
|
- forward
|
||||||
|
|
||||||
|
## Qwen2MoeForTokenClassification
|
||||||
|
|
||||||
|
[[autodoc]] Qwen2MoeForTokenClassification
|
||||||
|
- forward
|
||||||
|
|||||||
@@ -104,3 +104,8 @@ Now, to run the model with Flash Attention 2, refer to the snippet below:
|
|||||||
|
|
||||||
[[autodoc]] StableLmForSequenceClassification
|
[[autodoc]] StableLmForSequenceClassification
|
||||||
- forward
|
- forward
|
||||||
|
|
||||||
|
## StableLmForTokenClassification
|
||||||
|
|
||||||
|
[[autodoc]] StableLmForTokenClassification
|
||||||
|
- forward
|
||||||
|
|||||||
@@ -66,3 +66,8 @@ These ready-to-use checkpoints can be downloaded and used via the HuggingFace Hu
|
|||||||
|
|
||||||
[[autodoc]] Starcoder2ForSequenceClassification
|
[[autodoc]] Starcoder2ForSequenceClassification
|
||||||
- forward
|
- forward
|
||||||
|
|
||||||
|
## Starcoder2ForTokenClassification
|
||||||
|
|
||||||
|
[[autodoc]] Starcoder2ForTokenClassification
|
||||||
|
- forward
|
||||||
|
|||||||
@@ -2031,6 +2031,7 @@ else:
|
|||||||
[
|
[
|
||||||
"GemmaForCausalLM",
|
"GemmaForCausalLM",
|
||||||
"GemmaForSequenceClassification",
|
"GemmaForSequenceClassification",
|
||||||
|
"GemmaForTokenClassification",
|
||||||
"GemmaModel",
|
"GemmaModel",
|
||||||
"GemmaPreTrainedModel",
|
"GemmaPreTrainedModel",
|
||||||
]
|
]
|
||||||
@@ -2288,6 +2289,7 @@ else:
|
|||||||
"LlamaForCausalLM",
|
"LlamaForCausalLM",
|
||||||
"LlamaForQuestionAnswering",
|
"LlamaForQuestionAnswering",
|
||||||
"LlamaForSequenceClassification",
|
"LlamaForSequenceClassification",
|
||||||
|
"LlamaForTokenClassification",
|
||||||
"LlamaModel",
|
"LlamaModel",
|
||||||
"LlamaPreTrainedModel",
|
"LlamaPreTrainedModel",
|
||||||
]
|
]
|
||||||
@@ -2435,12 +2437,19 @@ else:
|
|||||||
[
|
[
|
||||||
"MistralForCausalLM",
|
"MistralForCausalLM",
|
||||||
"MistralForSequenceClassification",
|
"MistralForSequenceClassification",
|
||||||
|
"MistralForTokenClassification",
|
||||||
"MistralModel",
|
"MistralModel",
|
||||||
"MistralPreTrainedModel",
|
"MistralPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.mixtral"].extend(
|
_import_structure["models.mixtral"].extend(
|
||||||
["MixtralForCausalLM", "MixtralForSequenceClassification", "MixtralModel", "MixtralPreTrainedModel"]
|
[
|
||||||
|
"MixtralForCausalLM",
|
||||||
|
"MixtralForSequenceClassification",
|
||||||
|
"MixtralForTokenClassification",
|
||||||
|
"MixtralModel",
|
||||||
|
"MixtralPreTrainedModel",
|
||||||
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.mobilebert"].extend(
|
_import_structure["models.mobilebert"].extend(
|
||||||
[
|
[
|
||||||
@@ -2714,6 +2723,7 @@ else:
|
|||||||
[
|
[
|
||||||
"PersimmonForCausalLM",
|
"PersimmonForCausalLM",
|
||||||
"PersimmonForSequenceClassification",
|
"PersimmonForSequenceClassification",
|
||||||
|
"PersimmonForTokenClassification",
|
||||||
"PersimmonModel",
|
"PersimmonModel",
|
||||||
"PersimmonPreTrainedModel",
|
"PersimmonPreTrainedModel",
|
||||||
]
|
]
|
||||||
@@ -2810,6 +2820,7 @@ else:
|
|||||||
[
|
[
|
||||||
"Qwen2ForCausalLM",
|
"Qwen2ForCausalLM",
|
||||||
"Qwen2ForSequenceClassification",
|
"Qwen2ForSequenceClassification",
|
||||||
|
"Qwen2ForTokenClassification",
|
||||||
"Qwen2Model",
|
"Qwen2Model",
|
||||||
"Qwen2PreTrainedModel",
|
"Qwen2PreTrainedModel",
|
||||||
]
|
]
|
||||||
@@ -2818,6 +2829,7 @@ else:
|
|||||||
[
|
[
|
||||||
"Qwen2MoeForCausalLM",
|
"Qwen2MoeForCausalLM",
|
||||||
"Qwen2MoeForSequenceClassification",
|
"Qwen2MoeForSequenceClassification",
|
||||||
|
"Qwen2MoeForTokenClassification",
|
||||||
"Qwen2MoeModel",
|
"Qwen2MoeModel",
|
||||||
"Qwen2MoePreTrainedModel",
|
"Qwen2MoePreTrainedModel",
|
||||||
]
|
]
|
||||||
@@ -3066,6 +3078,7 @@ else:
|
|||||||
[
|
[
|
||||||
"StableLmForCausalLM",
|
"StableLmForCausalLM",
|
||||||
"StableLmForSequenceClassification",
|
"StableLmForSequenceClassification",
|
||||||
|
"StableLmForTokenClassification",
|
||||||
"StableLmModel",
|
"StableLmModel",
|
||||||
"StableLmPreTrainedModel",
|
"StableLmPreTrainedModel",
|
||||||
]
|
]
|
||||||
@@ -3074,6 +3087,7 @@ else:
|
|||||||
[
|
[
|
||||||
"Starcoder2ForCausalLM",
|
"Starcoder2ForCausalLM",
|
||||||
"Starcoder2ForSequenceClassification",
|
"Starcoder2ForSequenceClassification",
|
||||||
|
"Starcoder2ForTokenClassification",
|
||||||
"Starcoder2Model",
|
"Starcoder2Model",
|
||||||
"Starcoder2PreTrainedModel",
|
"Starcoder2PreTrainedModel",
|
||||||
]
|
]
|
||||||
@@ -6489,6 +6503,7 @@ if TYPE_CHECKING:
|
|||||||
from .models.gemma import (
|
from .models.gemma import (
|
||||||
GemmaForCausalLM,
|
GemmaForCausalLM,
|
||||||
GemmaForSequenceClassification,
|
GemmaForSequenceClassification,
|
||||||
|
GemmaForTokenClassification,
|
||||||
GemmaModel,
|
GemmaModel,
|
||||||
GemmaPreTrainedModel,
|
GemmaPreTrainedModel,
|
||||||
)
|
)
|
||||||
@@ -6686,6 +6701,7 @@ if TYPE_CHECKING:
|
|||||||
LlamaForCausalLM,
|
LlamaForCausalLM,
|
||||||
LlamaForQuestionAnswering,
|
LlamaForQuestionAnswering,
|
||||||
LlamaForSequenceClassification,
|
LlamaForSequenceClassification,
|
||||||
|
LlamaForTokenClassification,
|
||||||
LlamaModel,
|
LlamaModel,
|
||||||
LlamaPreTrainedModel,
|
LlamaPreTrainedModel,
|
||||||
)
|
)
|
||||||
@@ -6801,12 +6817,14 @@ if TYPE_CHECKING:
|
|||||||
from .models.mistral import (
|
from .models.mistral import (
|
||||||
MistralForCausalLM,
|
MistralForCausalLM,
|
||||||
MistralForSequenceClassification,
|
MistralForSequenceClassification,
|
||||||
|
MistralForTokenClassification,
|
||||||
MistralModel,
|
MistralModel,
|
||||||
MistralPreTrainedModel,
|
MistralPreTrainedModel,
|
||||||
)
|
)
|
||||||
from .models.mixtral import (
|
from .models.mixtral import (
|
||||||
MixtralForCausalLM,
|
MixtralForCausalLM,
|
||||||
MixtralForSequenceClassification,
|
MixtralForSequenceClassification,
|
||||||
|
MixtralForTokenClassification,
|
||||||
MixtralModel,
|
MixtralModel,
|
||||||
MixtralPreTrainedModel,
|
MixtralPreTrainedModel,
|
||||||
)
|
)
|
||||||
@@ -7025,6 +7043,7 @@ if TYPE_CHECKING:
|
|||||||
from .models.persimmon import (
|
from .models.persimmon import (
|
||||||
PersimmonForCausalLM,
|
PersimmonForCausalLM,
|
||||||
PersimmonForSequenceClassification,
|
PersimmonForSequenceClassification,
|
||||||
|
PersimmonForTokenClassification,
|
||||||
PersimmonModel,
|
PersimmonModel,
|
||||||
PersimmonPreTrainedModel,
|
PersimmonPreTrainedModel,
|
||||||
)
|
)
|
||||||
@@ -7099,12 +7118,14 @@ if TYPE_CHECKING:
|
|||||||
from .models.qwen2 import (
|
from .models.qwen2 import (
|
||||||
Qwen2ForCausalLM,
|
Qwen2ForCausalLM,
|
||||||
Qwen2ForSequenceClassification,
|
Qwen2ForSequenceClassification,
|
||||||
|
Qwen2ForTokenClassification,
|
||||||
Qwen2Model,
|
Qwen2Model,
|
||||||
Qwen2PreTrainedModel,
|
Qwen2PreTrainedModel,
|
||||||
)
|
)
|
||||||
from .models.qwen2_moe import (
|
from .models.qwen2_moe import (
|
||||||
Qwen2MoeForCausalLM,
|
Qwen2MoeForCausalLM,
|
||||||
Qwen2MoeForSequenceClassification,
|
Qwen2MoeForSequenceClassification,
|
||||||
|
Qwen2MoeForTokenClassification,
|
||||||
Qwen2MoeModel,
|
Qwen2MoeModel,
|
||||||
Qwen2MoePreTrainedModel,
|
Qwen2MoePreTrainedModel,
|
||||||
)
|
)
|
||||||
@@ -7306,12 +7327,14 @@ if TYPE_CHECKING:
|
|||||||
from .models.stablelm import (
|
from .models.stablelm import (
|
||||||
StableLmForCausalLM,
|
StableLmForCausalLM,
|
||||||
StableLmForSequenceClassification,
|
StableLmForSequenceClassification,
|
||||||
|
StableLmForTokenClassification,
|
||||||
StableLmModel,
|
StableLmModel,
|
||||||
StableLmPreTrainedModel,
|
StableLmPreTrainedModel,
|
||||||
)
|
)
|
||||||
from .models.starcoder2 import (
|
from .models.starcoder2 import (
|
||||||
Starcoder2ForCausalLM,
|
Starcoder2ForCausalLM,
|
||||||
Starcoder2ForSequenceClassification,
|
Starcoder2ForSequenceClassification,
|
||||||
|
Starcoder2ForTokenClassification,
|
||||||
Starcoder2Model,
|
Starcoder2Model,
|
||||||
Starcoder2PreTrainedModel,
|
Starcoder2PreTrainedModel,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1038,6 +1038,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
|||||||
("flaubert", "FlaubertForTokenClassification"),
|
("flaubert", "FlaubertForTokenClassification"),
|
||||||
("fnet", "FNetForTokenClassification"),
|
("fnet", "FNetForTokenClassification"),
|
||||||
("funnel", "FunnelForTokenClassification"),
|
("funnel", "FunnelForTokenClassification"),
|
||||||
|
("gemma", "GemmaForTokenClassification"),
|
||||||
("gpt-sw3", "GPT2ForTokenClassification"),
|
("gpt-sw3", "GPT2ForTokenClassification"),
|
||||||
("gpt2", "GPT2ForTokenClassification"),
|
("gpt2", "GPT2ForTokenClassification"),
|
||||||
("gpt_bigcode", "GPTBigCodeForTokenClassification"),
|
("gpt_bigcode", "GPTBigCodeForTokenClassification"),
|
||||||
@@ -1048,11 +1049,14 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
|||||||
("layoutlmv2", "LayoutLMv2ForTokenClassification"),
|
("layoutlmv2", "LayoutLMv2ForTokenClassification"),
|
||||||
("layoutlmv3", "LayoutLMv3ForTokenClassification"),
|
("layoutlmv3", "LayoutLMv3ForTokenClassification"),
|
||||||
("lilt", "LiltForTokenClassification"),
|
("lilt", "LiltForTokenClassification"),
|
||||||
|
("llama", "LlamaForTokenClassification"),
|
||||||
("longformer", "LongformerForTokenClassification"),
|
("longformer", "LongformerForTokenClassification"),
|
||||||
("luke", "LukeForTokenClassification"),
|
("luke", "LukeForTokenClassification"),
|
||||||
("markuplm", "MarkupLMForTokenClassification"),
|
("markuplm", "MarkupLMForTokenClassification"),
|
||||||
("mega", "MegaForTokenClassification"),
|
("mega", "MegaForTokenClassification"),
|
||||||
("megatron-bert", "MegatronBertForTokenClassification"),
|
("megatron-bert", "MegatronBertForTokenClassification"),
|
||||||
|
("mistral", "MistralForTokenClassification"),
|
||||||
|
("mixtral", "MixtralForTokenClassification"),
|
||||||
("mobilebert", "MobileBertForTokenClassification"),
|
("mobilebert", "MobileBertForTokenClassification"),
|
||||||
("mpnet", "MPNetForTokenClassification"),
|
("mpnet", "MPNetForTokenClassification"),
|
||||||
("mpt", "MptForTokenClassification"),
|
("mpt", "MptForTokenClassification"),
|
||||||
@@ -1060,15 +1064,20 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
|||||||
("mt5", "MT5ForTokenClassification"),
|
("mt5", "MT5ForTokenClassification"),
|
||||||
("nezha", "NezhaForTokenClassification"),
|
("nezha", "NezhaForTokenClassification"),
|
||||||
("nystromformer", "NystromformerForTokenClassification"),
|
("nystromformer", "NystromformerForTokenClassification"),
|
||||||
|
("persimmon", "PersimmonForTokenClassification"),
|
||||||
("phi", "PhiForTokenClassification"),
|
("phi", "PhiForTokenClassification"),
|
||||||
("phi3", "Phi3ForTokenClassification"),
|
("phi3", "Phi3ForTokenClassification"),
|
||||||
("qdqbert", "QDQBertForTokenClassification"),
|
("qdqbert", "QDQBertForTokenClassification"),
|
||||||
|
("qwen2", "Qwen2ForTokenClassification"),
|
||||||
|
("qwen2_moe", "Qwen2MoeForTokenClassification"),
|
||||||
("rembert", "RemBertForTokenClassification"),
|
("rembert", "RemBertForTokenClassification"),
|
||||||
("roberta", "RobertaForTokenClassification"),
|
("roberta", "RobertaForTokenClassification"),
|
||||||
("roberta-prelayernorm", "RobertaPreLayerNormForTokenClassification"),
|
("roberta-prelayernorm", "RobertaPreLayerNormForTokenClassification"),
|
||||||
("roc_bert", "RoCBertForTokenClassification"),
|
("roc_bert", "RoCBertForTokenClassification"),
|
||||||
("roformer", "RoFormerForTokenClassification"),
|
("roformer", "RoFormerForTokenClassification"),
|
||||||
("squeezebert", "SqueezeBertForTokenClassification"),
|
("squeezebert", "SqueezeBertForTokenClassification"),
|
||||||
|
("stablelm", "StableLmForTokenClassification"),
|
||||||
|
("starcoder2", "Starcoder2ForTokenClassification"),
|
||||||
("t5", "T5ForTokenClassification"),
|
("t5", "T5ForTokenClassification"),
|
||||||
("umt5", "UMT5ForTokenClassification"),
|
("umt5", "UMT5ForTokenClassification"),
|
||||||
("xlm", "XLMForTokenClassification"),
|
("xlm", "XLMForTokenClassification"),
|
||||||
|
|||||||
@@ -55,6 +55,7 @@ else:
|
|||||||
"GemmaModel",
|
"GemmaModel",
|
||||||
"GemmaPreTrainedModel",
|
"GemmaPreTrainedModel",
|
||||||
"GemmaForSequenceClassification",
|
"GemmaForSequenceClassification",
|
||||||
|
"GemmaForTokenClassification",
|
||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -98,6 +99,7 @@ if TYPE_CHECKING:
|
|||||||
from .modeling_gemma import (
|
from .modeling_gemma import (
|
||||||
GemmaForCausalLM,
|
GemmaForCausalLM,
|
||||||
GemmaForSequenceClassification,
|
GemmaForSequenceClassification,
|
||||||
|
GemmaForTokenClassification,
|
||||||
GemmaModel,
|
GemmaModel,
|
||||||
GemmaPreTrainedModel,
|
GemmaPreTrainedModel,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -30,7 +30,12 @@ from ...modeling_attn_mask_utils import (
|
|||||||
AttentionMaskConverter,
|
AttentionMaskConverter,
|
||||||
_prepare_4d_causal_attention_mask,
|
_prepare_4d_causal_attention_mask,
|
||||||
)
|
)
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
from ...modeling_outputs import (
|
||||||
|
BaseModelOutputWithPast,
|
||||||
|
CausalLMOutputWithPast,
|
||||||
|
SequenceClassifierOutputWithPast,
|
||||||
|
TokenClassifierOutput,
|
||||||
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
|
from ...pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
@@ -1346,3 +1351,88 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
|
|||||||
hidden_states=transformer_outputs.hidden_states,
|
hidden_states=transformer_outputs.hidden_states,
|
||||||
attentions=transformer_outputs.attentions,
|
attentions=transformer_outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
    """
    The Gemma Model transformer with a token classification head on top (a linear layer on top of the hidden-states
    output) e.g. for Named-Entity-Recognition (NER) tasks.
    """,
    GEMMA_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Gemma, LLAMA->GEMMA
class GemmaForTokenClassification(GemmaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = GemmaModel(config)
        # Dropout rate precedence: `config.classifier_dropout`, then `config.hidden_dropout`, then 0.1.
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        # Classification head: projects every hidden state to `num_labels` scores.
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss (Cross-Entropy). Indices should be in `[0, ...,
            config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # First element of the backbone output is fed through dropout and the classification head.
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.score(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Flatten so that each row of logits is scored against a single label.
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            # NOTE(review): skipping index 1 presumably drops the key/value cache from the tuple output - confirm
            # this still holds when the backbone is called with `use_cache=False`.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||||||
|
|||||||
@@ -55,6 +55,7 @@ else:
|
|||||||
"LlamaPreTrainedModel",
|
"LlamaPreTrainedModel",
|
||||||
"LlamaForSequenceClassification",
|
"LlamaForSequenceClassification",
|
||||||
"LlamaForQuestionAnswering",
|
"LlamaForQuestionAnswering",
|
||||||
|
"LlamaForTokenClassification",
|
||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -95,6 +96,7 @@ if TYPE_CHECKING:
|
|||||||
LlamaForCausalLM,
|
LlamaForCausalLM,
|
||||||
LlamaForQuestionAnswering,
|
LlamaForQuestionAnswering,
|
||||||
LlamaForSequenceClassification,
|
LlamaForSequenceClassification,
|
||||||
|
LlamaForTokenClassification,
|
||||||
LlamaModel,
|
LlamaModel,
|
||||||
LlamaPreTrainedModel,
|
LlamaPreTrainedModel,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ from ...modeling_outputs import (
|
|||||||
CausalLMOutputWithPast,
|
CausalLMOutputWithPast,
|
||||||
QuestionAnsweringModelOutput,
|
QuestionAnsweringModelOutput,
|
||||||
SequenceClassifierOutputWithPast,
|
SequenceClassifierOutputWithPast,
|
||||||
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
|
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
|
||||||
@@ -1516,3 +1517,87 @@ class LlamaForQuestionAnswering(LlamaPreTrainedModel):
|
|||||||
hidden_states=outputs.hidden_states,
|
hidden_states=outputs.hidden_states,
|
||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
    """
    The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states
    output) e.g. for Named-Entity-Recognition (NER) tasks.
    """,
    LLAMA_START_DOCSTRING,
)
class LlamaForTokenClassification(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = LlamaModel(config)
        # Dropout rate precedence: `config.classifier_dropout`, then `config.hidden_dropout`, then 0.1.
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        # Classification head: projects every hidden state to `num_labels` scores.
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss (Cross-Entropy). Indices should be in `[0, ...,
            config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # First element of the backbone output is fed through dropout and the classification head.
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.score(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Flatten so that each row of logits is scored against a single label.
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            # NOTE(review): skipping index 1 presumably drops the key/value cache from the tuple output - confirm
            # this still holds when the backbone is called with `use_cache=False`.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ else:
|
|||||||
"MistralModel",
|
"MistralModel",
|
||||||
"MistralPreTrainedModel",
|
"MistralPreTrainedModel",
|
||||||
"MistralForSequenceClassification",
|
"MistralForSequenceClassification",
|
||||||
|
"MistralForTokenClassification",
|
||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -59,6 +60,7 @@ if TYPE_CHECKING:
|
|||||||
from .modeling_mistral import (
|
from .modeling_mistral import (
|
||||||
MistralForCausalLM,
|
MistralForCausalLM,
|
||||||
MistralForSequenceClassification,
|
MistralForSequenceClassification,
|
||||||
|
MistralForTokenClassification,
|
||||||
MistralModel,
|
MistralModel,
|
||||||
MistralPreTrainedModel,
|
MistralPreTrainedModel,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -31,7 +31,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache
|
from ...cache_utils import Cache, DynamicCache
|
||||||
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
|
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
from ...modeling_outputs import (
|
||||||
|
BaseModelOutputWithPast,
|
||||||
|
CausalLMOutputWithPast,
|
||||||
|
SequenceClassifierOutputWithPast,
|
||||||
|
TokenClassifierOutput,
|
||||||
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
@@ -1366,3 +1371,88 @@ class MistralForSequenceClassification(MistralPreTrainedModel):
|
|||||||
hidden_states=transformer_outputs.hidden_states,
|
hidden_states=transformer_outputs.hidden_states,
|
||||||
attentions=transformer_outputs.attentions,
|
attentions=transformer_outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
    """
    The Mistral Model transformer with a token classification head on top (a linear layer on top of the hidden-states
    output) e.g. for Named-Entity-Recognition (NER) tasks.
    """,
    MISTRAL_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mistral, LLAMA->MISTRAL
class MistralForTokenClassification(MistralPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = MistralModel(config)
        # Dropout rate precedence: `config.classifier_dropout`, then `config.hidden_dropout`, then 0.1.
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        # Classification head: projects every hidden state to `num_labels` scores.
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss (Cross-Entropy). Indices should be in `[0, ...,
            config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # First element of the backbone output is fed through dropout and the classification head.
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.score(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Flatten so that each row of logits is scored against a single label.
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            # NOTE(review): skipping index 1 presumably drops the key/value cache from the tuple output - confirm
            # this still holds when the backbone is called with `use_cache=False`.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ else:
|
|||||||
"MixtralModel",
|
"MixtralModel",
|
||||||
"MixtralPreTrainedModel",
|
"MixtralPreTrainedModel",
|
||||||
"MixtralForSequenceClassification",
|
"MixtralForSequenceClassification",
|
||||||
|
"MixtralForTokenClassification",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -51,6 +52,7 @@ if TYPE_CHECKING:
|
|||||||
from .modeling_mixtral import (
|
from .modeling_mixtral import (
|
||||||
MixtralForCausalLM,
|
MixtralForCausalLM,
|
||||||
MixtralForSequenceClassification,
|
MixtralForSequenceClassification,
|
||||||
|
MixtralForTokenClassification,
|
||||||
MixtralModel,
|
MixtralModel,
|
||||||
MixtralPreTrainedModel,
|
MixtralPreTrainedModel,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ from ...modeling_outputs import (
|
|||||||
MoeCausalLMOutputWithPast,
|
MoeCausalLMOutputWithPast,
|
||||||
MoeModelOutputWithPast,
|
MoeModelOutputWithPast,
|
||||||
SequenceClassifierOutputWithPast,
|
SequenceClassifierOutputWithPast,
|
||||||
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...pytorch_utils import is_torch_greater_or_equal_than_1_13
|
from ...pytorch_utils import is_torch_greater_or_equal_than_1_13
|
||||||
@@ -1582,3 +1583,88 @@ class MixtralForSequenceClassification(MixtralPreTrainedModel):
|
|||||||
hidden_states=transformer_outputs.hidden_states,
|
hidden_states=transformer_outputs.hidden_states,
|
||||||
attentions=transformer_outputs.attentions,
|
attentions=transformer_outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
    """
    The Mixtral Model transformer with a token classification head on top (a linear layer on top of the hidden-states
    output) e.g. for Named-Entity-Recognition (NER) tasks.
    """,
    MIXTRAL_START_DOCSTRING,
)
# NOTE(review): the return annotation and `labels` docstring of `forward` were corrected here; if the Llama source
# still carries the sequence-classification wording, mirror the change there so the copy check stays in sync.
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mixtral, LLAMA->MIXTRAL
class MixtralForTokenClassification(MixtralPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = MixtralModel(config)
        # Dropout rate of the head: an explicit `classifier_dropout` wins, then `hidden_dropout`, and 0.1 is used
        # when the config defines neither (or defines them as `None`).
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        # Linear projection from each hidden state to `num_labels` scores.
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module of the wrapped base model."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the wrapped base model with `value`."""
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. The loss is always a Cross-Entropy loss over the flattened tokens; positions
            set to `-100` are skipped by the default `ignore_index` of `CrossEntropyLoss`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.score(sequence_output)

        loss = None
        if labels is not None:
            # Keep labels on the same device as the logits; this is a no-op when they already match.
            labels = labels.to(logits.device)
            loss_fct = CrossEntropyLoss()
            # Flatten every token position into one batch of `num_labels`-way predictions.
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            # Skip the first two entries of the base-model tuple: index 0 was consumed above, index 1 is
            # presumably the cache -- TODO confirm against the base model's tuple layout.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ else:
|
|||||||
"PersimmonModel",
|
"PersimmonModel",
|
||||||
"PersimmonPreTrainedModel",
|
"PersimmonPreTrainedModel",
|
||||||
"PersimmonForSequenceClassification",
|
"PersimmonForSequenceClassification",
|
||||||
|
"PersimmonForTokenClassification",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -51,6 +52,7 @@ if TYPE_CHECKING:
|
|||||||
from .modeling_persimmon import (
|
from .modeling_persimmon import (
|
||||||
PersimmonForCausalLM,
|
PersimmonForCausalLM,
|
||||||
PersimmonForSequenceClassification,
|
PersimmonForSequenceClassification,
|
||||||
|
PersimmonForTokenClassification,
|
||||||
PersimmonModel,
|
PersimmonModel,
|
||||||
PersimmonPreTrainedModel,
|
PersimmonPreTrainedModel,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -29,7 +29,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache
|
from ...cache_utils import Cache, DynamicCache
|
||||||
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
from ...modeling_outputs import (
|
||||||
|
BaseModelOutputWithPast,
|
||||||
|
CausalLMOutputWithPast,
|
||||||
|
SequenceClassifierOutputWithPast,
|
||||||
|
TokenClassifierOutput,
|
||||||
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
||||||
from .configuration_persimmon import PersimmonConfig
|
from .configuration_persimmon import PersimmonConfig
|
||||||
@@ -1011,3 +1016,88 @@ class PersimmonForSequenceClassification(PersimmonPreTrainedModel):
|
|||||||
hidden_states=transformer_outputs.hidden_states,
|
hidden_states=transformer_outputs.hidden_states,
|
||||||
attentions=transformer_outputs.attentions,
|
attentions=transformer_outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
    """
    The Persimmon Model transformer with a token classification head on top (a linear layer on top of the hidden-states
    output) e.g. for Named-Entity-Recognition (NER) tasks.
    """,
    PERSIMMON_START_DOCSTRING,
)
# NOTE(review): the return annotation and `labels` docstring of `forward` were corrected here; if the Llama source
# still carries the sequence-classification wording, mirror the change there so the copy check stays in sync.
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Persimmon, LLAMA->PERSIMMON
class PersimmonForTokenClassification(PersimmonPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = PersimmonModel(config)
        # Dropout rate of the head: an explicit `classifier_dropout` wins, then `hidden_dropout`, and 0.1 is used
        # when the config defines neither (or defines them as `None`).
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        # Linear projection from each hidden state to `num_labels` scores.
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module of the wrapped base model."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the wrapped base model with `value`."""
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(PERSIMMON_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. The loss is always a Cross-Entropy loss over the flattened tokens; positions
            set to `-100` are skipped by the default `ignore_index` of `CrossEntropyLoss`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.score(sequence_output)

        loss = None
        if labels is not None:
            # Keep labels on the same device as the logits; this is a no-op when they already match.
            labels = labels.to(logits.device)
            loss_fct = CrossEntropyLoss()
            # Flatten every token position into one batch of `num_labels`-way predictions.
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            # Skip the first two entries of the base-model tuple: index 0 was consumed above, index 1 is
            # presumably the cache -- TODO confirm against the base model's tuple layout.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ else:
|
|||||||
"Qwen2Model",
|
"Qwen2Model",
|
||||||
"Qwen2PreTrainedModel",
|
"Qwen2PreTrainedModel",
|
||||||
"Qwen2ForSequenceClassification",
|
"Qwen2ForSequenceClassification",
|
||||||
|
"Qwen2ForTokenClassification",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -69,6 +70,7 @@ if TYPE_CHECKING:
|
|||||||
from .modeling_qwen2 import (
|
from .modeling_qwen2 import (
|
||||||
Qwen2ForCausalLM,
|
Qwen2ForCausalLM,
|
||||||
Qwen2ForSequenceClassification,
|
Qwen2ForSequenceClassification,
|
||||||
|
Qwen2ForTokenClassification,
|
||||||
Qwen2Model,
|
Qwen2Model,
|
||||||
Qwen2PreTrainedModel,
|
Qwen2PreTrainedModel,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -31,7 +31,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache
|
from ...cache_utils import Cache, DynamicCache
|
||||||
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
|
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
from ...modeling_outputs import (
|
||||||
|
BaseModelOutputWithPast,
|
||||||
|
CausalLMOutputWithPast,
|
||||||
|
SequenceClassifierOutputWithPast,
|
||||||
|
TokenClassifierOutput,
|
||||||
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
@@ -1375,3 +1380,88 @@ class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
|
|||||||
hidden_states=transformer_outputs.hidden_states,
|
hidden_states=transformer_outputs.hidden_states,
|
||||||
attentions=transformer_outputs.attentions,
|
attentions=transformer_outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
    """
    The Qwen2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states
    output) e.g. for Named-Entity-Recognition (NER) tasks.
    """,
    QWEN2_START_DOCSTRING,
)
# NOTE(review): the return annotation and `labels` docstring of `forward` were corrected here; if the Llama source
# still carries the sequence-classification wording, mirror the change there so the copy check stays in sync.
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2, LLAMA->QWEN2
class Qwen2ForTokenClassification(Qwen2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Qwen2Model(config)
        # Dropout rate of the head: an explicit `classifier_dropout` wins, then `hidden_dropout`, and 0.1 is used
        # when the config defines neither (or defines them as `None`).
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        # Linear projection from each hidden state to `num_labels` scores.
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module of the wrapped base model."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the wrapped base model with `value`."""
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. The loss is always a Cross-Entropy loss over the flattened tokens; positions
            set to `-100` are skipped by the default `ignore_index` of `CrossEntropyLoss`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.score(sequence_output)

        loss = None
        if labels is not None:
            # Keep labels on the same device as the logits; this is a no-op when they already match.
            labels = labels.to(logits.device)
            loss_fct = CrossEntropyLoss()
            # Flatten every token position into one batch of `num_labels`-way predictions.
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            # Skip the first two entries of the base-model tuple: index 0 was consumed above, index 1 is
            # presumably the cache -- TODO confirm against the base model's tuple layout.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ else:
|
|||||||
"Qwen2MoeModel",
|
"Qwen2MoeModel",
|
||||||
"Qwen2MoePreTrainedModel",
|
"Qwen2MoePreTrainedModel",
|
||||||
"Qwen2MoeForSequenceClassification",
|
"Qwen2MoeForSequenceClassification",
|
||||||
|
"Qwen2MoeForTokenClassification",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -51,6 +52,7 @@ if TYPE_CHECKING:
|
|||||||
from .modeling_qwen2_moe import (
|
from .modeling_qwen2_moe import (
|
||||||
Qwen2MoeForCausalLM,
|
Qwen2MoeForCausalLM,
|
||||||
Qwen2MoeForSequenceClassification,
|
Qwen2MoeForSequenceClassification,
|
||||||
|
Qwen2MoeForTokenClassification,
|
||||||
Qwen2MoeModel,
|
Qwen2MoeModel,
|
||||||
Qwen2MoePreTrainedModel,
|
Qwen2MoePreTrainedModel,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -32,7 +32,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache
|
from ...cache_utils import Cache, DynamicCache
|
||||||
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
|
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
|
||||||
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast, SequenceClassifierOutputWithPast
|
from ...modeling_outputs import (
|
||||||
|
MoeCausalLMOutputWithPast,
|
||||||
|
MoeModelOutputWithPast,
|
||||||
|
SequenceClassifierOutputWithPast,
|
||||||
|
TokenClassifierOutput,
|
||||||
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
@@ -1571,3 +1576,88 @@ class Qwen2MoeForSequenceClassification(Qwen2MoePreTrainedModel):
|
|||||||
hidden_states=transformer_outputs.hidden_states,
|
hidden_states=transformer_outputs.hidden_states,
|
||||||
attentions=transformer_outputs.attentions,
|
attentions=transformer_outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
    """
    The Qwen2MoE Model transformer with a token classification head on top (a linear layer on top of the hidden-states
    output) e.g. for Named-Entity-Recognition (NER) tasks.
    """,
    QWEN2MOE_START_DOCSTRING,
)
# NOTE(review): the return annotation and `labels` docstring of `forward` were corrected here; if the Llama source
# still carries the sequence-classification wording, mirror the change there so the copy check stays in sync.
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2Moe, LLAMA->QWEN2MOE
class Qwen2MoeForTokenClassification(Qwen2MoePreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Qwen2MoeModel(config)
        # Dropout rate of the head: an explicit `classifier_dropout` wins, then `hidden_dropout`, and 0.1 is used
        # when the config defines neither (or defines them as `None`).
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        # Linear projection from each hidden state to `num_labels` scores.
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module of the wrapped base model."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the wrapped base model with `value`."""
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. The loss is always a Cross-Entropy loss over the flattened tokens; positions
            set to `-100` are skipped by the default `ignore_index` of `CrossEntropyLoss`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.score(sequence_output)

        loss = None
        if labels is not None:
            # Keep labels on the same device as the logits; this is a no-op when they already match.
            labels = labels.to(logits.device)
            loss_fct = CrossEntropyLoss()
            # Flatten every token position into one batch of `num_labels`-way predictions.
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            # Skip the first two entries of the base-model tuple: index 0 was consumed above, index 1 is
            # presumably the cache -- TODO confirm against the base model's tuple layout.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ else:
|
|||||||
"StableLmModel",
|
"StableLmModel",
|
||||||
"StableLmPreTrainedModel",
|
"StableLmPreTrainedModel",
|
||||||
"StableLmForSequenceClassification",
|
"StableLmForSequenceClassification",
|
||||||
|
"StableLmForTokenClassification",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -51,6 +52,7 @@ if TYPE_CHECKING:
|
|||||||
from .modeling_stablelm import (
|
from .modeling_stablelm import (
|
||||||
StableLmForCausalLM,
|
StableLmForCausalLM,
|
||||||
StableLmForSequenceClassification,
|
StableLmForSequenceClassification,
|
||||||
|
StableLmForTokenClassification,
|
||||||
StableLmModel,
|
StableLmModel,
|
||||||
StableLmPreTrainedModel,
|
StableLmPreTrainedModel,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -30,7 +30,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache
|
from ...cache_utils import Cache, DynamicCache
|
||||||
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
|
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
from ...modeling_outputs import (
|
||||||
|
BaseModelOutputWithPast,
|
||||||
|
CausalLMOutputWithPast,
|
||||||
|
SequenceClassifierOutputWithPast,
|
||||||
|
TokenClassifierOutput,
|
||||||
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
@@ -1383,3 +1388,88 @@ class StableLmForSequenceClassification(StableLmPreTrainedModel):
|
|||||||
hidden_states=transformer_outputs.hidden_states,
|
hidden_states=transformer_outputs.hidden_states,
|
||||||
attentions=transformer_outputs.attentions,
|
attentions=transformer_outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
    """
    The StableLm Model transformer with a token classification head on top (a linear layer on top of the hidden-states
    output) e.g. for Named-Entity-Recognition (NER) tasks.
    """,
    STABLELM_START_DOCSTRING,
)
# NOTE(review): the return annotation and `labels` docstring of `forward` were corrected here; if the Llama source
# still carries the sequence-classification wording, mirror the change there so the copy check stays in sync.
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->StableLm, LLAMA->STABLELM
class StableLmForTokenClassification(StableLmPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = StableLmModel(config)
        # Dropout rate of the head: an explicit `classifier_dropout` wins, then `hidden_dropout`, and 0.1 is used
        # when the config defines neither (or defines them as `None`).
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        # Linear projection from each hidden state to `num_labels` scores.
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module of the wrapped base model."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the wrapped base model with `value`."""
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(STABLELM_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. The loss is always a Cross-Entropy loss over the flattened tokens; positions
            set to `-100` are skipped by the default `ignore_index` of `CrossEntropyLoss`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.score(sequence_output)

        loss = None
        if labels is not None:
            # Keep labels on the same device as the logits; this is a no-op when they already match.
            labels = labels.to(logits.device)
            loss_fct = CrossEntropyLoss()
            # Flatten every token position into one batch of `num_labels`-way predictions.
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            # Skip the first two entries of the base-model tuple: index 0 was consumed above, index 1 is
            # presumably the cache -- TODO confirm against the base model's tuple layout.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ else:
|
|||||||
"Starcoder2Model",
|
"Starcoder2Model",
|
||||||
"Starcoder2PreTrainedModel",
|
"Starcoder2PreTrainedModel",
|
||||||
"Starcoder2ForSequenceClassification",
|
"Starcoder2ForSequenceClassification",
|
||||||
|
"Starcoder2ForTokenClassification",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -51,6 +52,7 @@ if TYPE_CHECKING:
|
|||||||
from .modeling_starcoder2 import (
|
from .modeling_starcoder2 import (
|
||||||
Starcoder2ForCausalLM,
|
Starcoder2ForCausalLM,
|
||||||
Starcoder2ForSequenceClassification,
|
Starcoder2ForSequenceClassification,
|
||||||
|
Starcoder2ForTokenClassification,
|
||||||
Starcoder2Model,
|
Starcoder2Model,
|
||||||
Starcoder2PreTrainedModel,
|
Starcoder2PreTrainedModel,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -31,7 +31,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache
|
from ...cache_utils import Cache, DynamicCache
|
||||||
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
|
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
from ...modeling_outputs import (
|
||||||
|
BaseModelOutputWithPast,
|
||||||
|
CausalLMOutputWithPast,
|
||||||
|
SequenceClassifierOutputWithPast,
|
||||||
|
TokenClassifierOutput,
|
||||||
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
@@ -1357,3 +1362,88 @@ class Starcoder2ForSequenceClassification(Starcoder2PreTrainedModel):
|
|||||||
hidden_states=transformer_outputs.hidden_states,
|
hidden_states=transformer_outputs.hidden_states,
|
||||||
attentions=transformer_outputs.attentions,
|
attentions=transformer_outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
    """
    The Starcoder2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states
    output) e.g. for Named-Entity-Recognition (NER) tasks.
    """,
    STARCODER2_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Starcoder2, LLAMA->STARCODER2
class Starcoder2ForTokenClassification(Starcoder2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Starcoder2Model(config)

        # Dropout rate precedence: `config.classifier_dropout`, then `config.hidden_dropout`,
        # then a fixed fallback of 0.1 when the config defines neither.
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        # Per-token classification head applied to the backbone's first output.
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module of the wrapped backbone."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the wrapped backbone with `value`."""
        self.model.embed_tokens = value

    # NOTE(review): the return annotation and `labels` documentation below were corrected for token
    # classification; the Llama class named in the "Copied from" marker presumably needs the same
    # correction for the copy check to stay in sync -- confirm.
    @add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. A Cross-Entropy loss is computed over every token position; positions
            labelled `-100` are skipped by `CrossEntropyLoss` (its default `ignore_index`).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # First element of the backbone output feeds the head.
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.score(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Flatten batch and sequence dimensions so each token is scored independently.
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||||||
|
|||||||
@@ -3692,6 +3692,13 @@ class GemmaForSequenceClassification(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class GemmaForTokenClassification(metaclass=DummyObject):
    # Dummy stand-in exported under the real class name: constructing it only calls
    # `requires_backends` for the "torch" backend (presumably raising when torch is
    # missing -- confirm against the helper's definition).
    _backends = ["torch"]

    def __init__(self, *_args, **_kwargs):
        requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class GemmaModel(metaclass=DummyObject):
|
class GemmaModel(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
@@ -4642,6 +4649,13 @@ class LlamaForSequenceClassification(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaForTokenClassification(metaclass=DummyObject):
    # Dummy stand-in exported under the real class name: constructing it only calls
    # `requires_backends` for the "torch" backend (presumably raising when torch is
    # missing -- confirm against the helper's definition).
    _backends = ["torch"]

    def __init__(self, *_args, **_kwargs):
        requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class LlamaModel(metaclass=DummyObject):
|
class LlamaModel(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
@@ -5237,6 +5251,13 @@ class MistralForSequenceClassification(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class MistralForTokenClassification(metaclass=DummyObject):
    # Dummy stand-in exported under the real class name: constructing it only calls
    # `requires_backends` for the "torch" backend (presumably raising when torch is
    # missing -- confirm against the helper's definition).
    _backends = ["torch"]

    def __init__(self, *_args, **_kwargs):
        requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class MistralModel(metaclass=DummyObject):
|
class MistralModel(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
@@ -5265,6 +5286,13 @@ class MixtralForSequenceClassification(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class MixtralForTokenClassification(metaclass=DummyObject):
    # Dummy stand-in exported under the real class name: constructing it only calls
    # `requires_backends` for the "torch" backend (presumably raising when torch is
    # missing -- confirm against the helper's definition).
    _backends = ["torch"]

    def __init__(self, *_args, **_kwargs):
        requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class MixtralModel(metaclass=DummyObject):
|
class MixtralModel(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
@@ -6373,6 +6401,13 @@ class PersimmonForSequenceClassification(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class PersimmonForTokenClassification(metaclass=DummyObject):
    # Dummy stand-in exported under the real class name: constructing it only calls
    # `requires_backends` for the "torch" backend (presumably raising when torch is
    # missing -- confirm against the helper's definition).
    _backends = ["torch"]

    def __init__(self, *_args, **_kwargs):
        requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class PersimmonModel(metaclass=DummyObject):
|
class PersimmonModel(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
@@ -6734,6 +6769,13 @@ class Qwen2ForSequenceClassification(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2ForTokenClassification(metaclass=DummyObject):
    # Dummy stand-in exported under the real class name: constructing it only calls
    # `requires_backends` for the "torch" backend (presumably raising when torch is
    # missing -- confirm against the helper's definition).
    _backends = ["torch"]

    def __init__(self, *_args, **_kwargs):
        requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class Qwen2Model(metaclass=DummyObject):
|
class Qwen2Model(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
@@ -6762,6 +6804,13 @@ class Qwen2MoeForSequenceClassification(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2MoeForTokenClassification(metaclass=DummyObject):
    # Dummy stand-in exported under the real class name: constructing it only calls
    # `requires_backends` for the "torch" backend (presumably raising when torch is
    # missing -- confirm against the helper's definition).
    _backends = ["torch"]

    def __init__(self, *_args, **_kwargs):
        requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class Qwen2MoeModel(metaclass=DummyObject):
|
class Qwen2MoeModel(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
@@ -7793,6 +7842,13 @@ class StableLmForSequenceClassification(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class StableLmForTokenClassification(metaclass=DummyObject):
    # Dummy stand-in exported under the real class name: constructing it only calls
    # `requires_backends` for the "torch" backend (presumably raising when torch is
    # missing -- confirm against the helper's definition).
    _backends = ["torch"]

    def __init__(self, *_args, **_kwargs):
        requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class StableLmModel(metaclass=DummyObject):
|
class StableLmModel(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
@@ -7821,6 +7877,13 @@ class Starcoder2ForSequenceClassification(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class Starcoder2ForTokenClassification(metaclass=DummyObject):
    # Dummy stand-in exported under the real class name: constructing it only calls
    # `requires_backends` for the "torch" backend (presumably raising when torch is
    # missing -- confirm against the helper's definition).
    _backends = ["torch"]

    def __init__(self, *_args, **_kwargs):
        requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class Starcoder2Model(metaclass=DummyObject):
|
class Starcoder2Model(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
|||||||
@@ -41,7 +41,13 @@ from ...test_pipeline_mixin import PipelineTesterMixin
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import GemmaForCausalLM, GemmaForSequenceClassification, GemmaModel, GemmaTokenizer
|
from transformers import (
|
||||||
|
GemmaForCausalLM,
|
||||||
|
GemmaForSequenceClassification,
|
||||||
|
GemmaForTokenClassification,
|
||||||
|
GemmaModel,
|
||||||
|
GemmaTokenizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class GemmaModelTester:
|
class GemmaModelTester:
|
||||||
@@ -284,12 +290,17 @@ class GemmaModelTester:
|
|||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (GemmaModel, GemmaForCausalLM, GemmaForSequenceClassification) if is_torch_available() else ()
|
all_model_classes = (
|
||||||
|
(GemmaModel, GemmaForCausalLM, GemmaForSequenceClassification, GemmaForTokenClassification)
|
||||||
|
if is_torch_available()
|
||||||
|
else ()
|
||||||
|
)
|
||||||
all_generative_model_classes = (GemmaForCausalLM,) if is_torch_available() else ()
|
all_generative_model_classes = (GemmaForCausalLM,) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{
|
{
|
||||||
"feature-extraction": GemmaModel,
|
"feature-extraction": GemmaModel,
|
||||||
"text-classification": GemmaForSequenceClassification,
|
"text-classification": GemmaForSequenceClassification,
|
||||||
|
"token-classification": GemmaForTokenClassification,
|
||||||
"text-generation": GemmaForCausalLM,
|
"text-generation": GemmaForCausalLM,
|
||||||
"zero-shot": GemmaForSequenceClassification,
|
"zero-shot": GemmaForSequenceClassification,
|
||||||
}
|
}
|
||||||
@@ -370,6 +381,22 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||||
|
|
||||||
|
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Gemma,llama->Gemma
|
||||||
|
def test_Gemma_token_classification_model(self):
    """Run the token-classification head once and check it yields one logit vector per token."""
    tester = self.model_tester
    config, inputs = tester.prepare_config_and_inputs_for_common()
    config.num_labels = 3

    token_ids = inputs["input_ids"]
    # The attention mask is False wherever the input id equals 1.
    mask = token_ids.ne(1).to(torch_device)
    labels = ids_tensor([tester.batch_size, tester.seq_length], config.num_labels)

    classifier = GemmaForTokenClassification(config=config)
    classifier.to(torch_device)
    classifier.eval()

    outputs = classifier(token_ids, attention_mask=mask, labels=labels)
    expected_shape = (tester.batch_size, tester.seq_length, tester.num_labels)
    self.assertEqual(outputs.logits.shape, expected_shape)
|
||||||
|
|
||||||
@unittest.skip("Gemma buffers include complex numbers, which breaks this test")
|
@unittest.skip("Gemma buffers include complex numbers, which breaks this test")
|
||||||
def test_save_load_fast_init_from_base(self):
|
def test_save_load_fast_init_from_base(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ if is_torch_available():
|
|||||||
LlamaForCausalLM,
|
LlamaForCausalLM,
|
||||||
LlamaForQuestionAnswering,
|
LlamaForQuestionAnswering,
|
||||||
LlamaForSequenceClassification,
|
LlamaForSequenceClassification,
|
||||||
|
LlamaForTokenClassification,
|
||||||
LlamaModel,
|
LlamaModel,
|
||||||
LlamaTokenizer,
|
LlamaTokenizer,
|
||||||
)
|
)
|
||||||
@@ -286,7 +287,13 @@ class LlamaModelTester:
|
|||||||
@require_torch
|
@require_torch
|
||||||
class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(LlamaModel, LlamaForCausalLM, LlamaForSequenceClassification, LlamaForQuestionAnswering)
|
(
|
||||||
|
LlamaModel,
|
||||||
|
LlamaForCausalLM,
|
||||||
|
LlamaForSequenceClassification,
|
||||||
|
LlamaForQuestionAnswering,
|
||||||
|
LlamaForTokenClassification,
|
||||||
|
)
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
@@ -298,6 +305,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
"text-generation": LlamaForCausalLM,
|
"text-generation": LlamaForCausalLM,
|
||||||
"zero-shot": LlamaForSequenceClassification,
|
"zero-shot": LlamaForSequenceClassification,
|
||||||
"question-answering": LlamaForQuestionAnswering,
|
"question-answering": LlamaForQuestionAnswering,
|
||||||
|
"token-classification": LlamaForTokenClassification,
|
||||||
}
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
@@ -370,6 +378,21 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||||
|
|
||||||
|
def test_llama_token_classification_model(self):
    """Run the token-classification head once and check it yields one logit vector per token."""
    tester = self.model_tester
    config, inputs = tester.prepare_config_and_inputs_for_common()
    config.num_labels = 3

    token_ids = inputs["input_ids"]
    # The attention mask is False wherever the input id equals 1.
    mask = token_ids.ne(1).to(torch_device)
    labels = ids_tensor([tester.batch_size, tester.seq_length], config.num_labels)

    classifier = LlamaForTokenClassification(config=config)
    classifier.to(torch_device)
    classifier.eval()

    outputs = classifier(token_ids, attention_mask=mask, labels=labels)
    expected_shape = (tester.batch_size, tester.seq_length, tester.num_labels)
    self.assertEqual(outputs.logits.shape, expected_shape)
|
||||||
|
|
||||||
@unittest.skip("Llama buffers include complex numbers, which breaks this test")
|
@unittest.skip("Llama buffers include complex numbers, which breaks this test")
|
||||||
def test_save_load_fast_init_from_base(self):
|
def test_save_load_fast_init_from_base(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ if is_torch_available():
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
MistralForCausalLM,
|
MistralForCausalLM,
|
||||||
MistralForSequenceClassification,
|
MistralForSequenceClassification,
|
||||||
|
MistralForTokenClassification,
|
||||||
MistralModel,
|
MistralModel,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -288,13 +289,16 @@ class MistralModelTester:
|
|||||||
@require_torch
|
@require_torch
|
||||||
class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(MistralModel, MistralForCausalLM, MistralForSequenceClassification) if is_torch_available() else ()
|
(MistralModel, MistralForCausalLM, MistralForSequenceClassification, MistralForTokenClassification)
|
||||||
|
if is_torch_available()
|
||||||
|
else ()
|
||||||
)
|
)
|
||||||
all_generative_model_classes = (MistralForCausalLM,) if is_torch_available() else ()
|
all_generative_model_classes = (MistralForCausalLM,) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{
|
{
|
||||||
"feature-extraction": MistralModel,
|
"feature-extraction": MistralModel,
|
||||||
"text-classification": MistralForSequenceClassification,
|
"text-classification": MistralForSequenceClassification,
|
||||||
|
"token-classification": MistralForTokenClassification,
|
||||||
"text-generation": MistralForCausalLM,
|
"text-generation": MistralForCausalLM,
|
||||||
"zero-shot": MistralForSequenceClassification,
|
"zero-shot": MistralForSequenceClassification,
|
||||||
}
|
}
|
||||||
@@ -376,6 +380,22 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||||
|
|
||||||
|
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Mistral,llama->Mistral
|
||||||
|
def test_Mistral_token_classification_model(self):
    """Run the token-classification head once and check it yields one logit vector per token."""
    tester = self.model_tester
    config, inputs = tester.prepare_config_and_inputs_for_common()
    config.num_labels = 3

    token_ids = inputs["input_ids"]
    # The attention mask is False wherever the input id equals 1.
    mask = token_ids.ne(1).to(torch_device)
    labels = ids_tensor([tester.batch_size, tester.seq_length], config.num_labels)

    classifier = MistralForTokenClassification(config=config)
    classifier.to(torch_device)
    classifier.eval()

    outputs = classifier(token_ids, attention_mask=mask, labels=labels)
    expected_shape = (tester.batch_size, tester.seq_length, tester.num_labels)
    self.assertEqual(outputs.logits.shape, expected_shape)
|
||||||
|
|
||||||
@unittest.skip("Mistral buffers include complex numbers, which breaks this test")
|
@unittest.skip("Mistral buffers include complex numbers, which breaks this test")
|
||||||
def test_save_load_fast_init_from_base(self):
|
def test_save_load_fast_init_from_base(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -40,7 +40,12 @@ from ...test_pipeline_mixin import PipelineTesterMixin
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import MixtralForCausalLM, MixtralForSequenceClassification, MixtralModel
|
from transformers import (
|
||||||
|
MixtralForCausalLM,
|
||||||
|
MixtralForSequenceClassification,
|
||||||
|
MixtralForTokenClassification,
|
||||||
|
MixtralModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MixtralModelTester:
|
class MixtralModelTester:
|
||||||
@@ -287,13 +292,16 @@ class MixtralModelTester:
|
|||||||
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Mixtral
|
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Mixtral
|
||||||
class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(MixtralModel, MixtralForCausalLM, MixtralForSequenceClassification) if is_torch_available() else ()
|
(MixtralModel, MixtralForCausalLM, MixtralForSequenceClassification, MixtralForTokenClassification)
|
||||||
|
if is_torch_available()
|
||||||
|
else ()
|
||||||
)
|
)
|
||||||
all_generative_model_classes = (MixtralForCausalLM,) if is_torch_available() else ()
|
all_generative_model_classes = (MixtralForCausalLM,) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{
|
{
|
||||||
"feature-extraction": MixtralModel,
|
"feature-extraction": MixtralModel,
|
||||||
"text-classification": MixtralForSequenceClassification,
|
"text-classification": MixtralForSequenceClassification,
|
||||||
|
"token-classification": MixtralForTokenClassification,
|
||||||
"text-generation": MixtralForCausalLM,
|
"text-generation": MixtralForCausalLM,
|
||||||
"zero-shot": MixtralForSequenceClassification,
|
"zero-shot": MixtralForSequenceClassification,
|
||||||
}
|
}
|
||||||
@@ -375,6 +383,22 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||||
|
|
||||||
|
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Mixtral,llama->Mixtral
|
||||||
|
def test_Mixtral_token_classification_model(self):
    """Run the token-classification head once and check it yields one logit vector per token."""
    tester = self.model_tester
    config, inputs = tester.prepare_config_and_inputs_for_common()
    config.num_labels = 3

    token_ids = inputs["input_ids"]
    # The attention mask is False wherever the input id equals 1.
    mask = token_ids.ne(1).to(torch_device)
    labels = ids_tensor([tester.batch_size, tester.seq_length], config.num_labels)

    classifier = MixtralForTokenClassification(config=config)
    classifier.to(torch_device)
    classifier.eval()

    outputs = classifier(token_ids, attention_mask=mask, labels=labels)
    expected_shape = (tester.batch_size, tester.seq_length, tester.num_labels)
    self.assertEqual(outputs.logits.shape, expected_shape)
|
||||||
|
|
||||||
@unittest.skip("Mixtral buffers include complex numbers, which breaks this test")
|
@unittest.skip("Mixtral buffers include complex numbers, which breaks this test")
|
||||||
def test_save_load_fast_init_from_base(self):
|
def test_save_load_fast_init_from_base(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ if is_torch_available():
|
|||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
PersimmonForCausalLM,
|
PersimmonForCausalLM,
|
||||||
PersimmonForSequenceClassification,
|
PersimmonForSequenceClassification,
|
||||||
|
PersimmonForTokenClassification,
|
||||||
PersimmonModel,
|
PersimmonModel,
|
||||||
)
|
)
|
||||||
from transformers.models.persimmon.modeling_persimmon import (
|
from transformers.models.persimmon.modeling_persimmon import (
|
||||||
@@ -283,12 +284,15 @@ class PersimmonModelTester:
|
|||||||
@require_torch
|
@require_torch
|
||||||
class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(PersimmonModel, PersimmonForCausalLM, PersimmonForSequenceClassification) if is_torch_available() else ()
|
(PersimmonModel, PersimmonForCausalLM, PersimmonForSequenceClassification, PersimmonForTokenClassification)
|
||||||
|
if is_torch_available()
|
||||||
|
else ()
|
||||||
)
|
)
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{
|
{
|
||||||
"feature-extraction": PersimmonModel,
|
"feature-extraction": PersimmonModel,
|
||||||
"text-classification": PersimmonForSequenceClassification,
|
"text-classification": PersimmonForSequenceClassification,
|
||||||
|
"token-classification": PersimmonForTokenClassification,
|
||||||
# TODO (ydshieh): check why these two fail. Fix them or skip them in a better way.
|
# TODO (ydshieh): check why these two fail. Fix them or skip them in a better way.
|
||||||
# "text-generation": PersimmonForCausalLM,
|
# "text-generation": PersimmonForCausalLM,
|
||||||
# "zero-shot": PersimmonForSequenceClassification,
|
# "zero-shot": PersimmonForSequenceClassification,
|
||||||
@@ -365,6 +369,22 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||||
|
|
||||||
|
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Persimmon,llama->persimmon
|
||||||
|
def test_persimmon_token_classification_model(self):
    """Run the token-classification head once and check it yields one logit vector per token."""
    tester = self.model_tester
    config, inputs = tester.prepare_config_and_inputs_for_common()
    config.num_labels = 3

    token_ids = inputs["input_ids"]
    # The attention mask is False wherever the input id equals 1.
    mask = token_ids.ne(1).to(torch_device)
    labels = ids_tensor([tester.batch_size, tester.seq_length], config.num_labels)

    classifier = PersimmonForTokenClassification(config=config)
    classifier.to(torch_device)
    classifier.eval()

    outputs = classifier(token_ids, attention_mask=mask, labels=labels)
    expected_shape = (tester.batch_size, tester.seq_length, tester.num_labels)
    self.assertEqual(outputs.logits.shape, expected_shape)
|
||||||
|
|
||||||
@unittest.skip("Persimmon buffers include complex numbers, which breaks this test")
|
@unittest.skip("Persimmon buffers include complex numbers, which breaks this test")
|
||||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_save_load_fast_init_from_base
|
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_save_load_fast_init_from_base
|
||||||
def test_save_load_fast_init_from_base(self):
|
def test_save_load_fast_init_from_base(self):
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ if is_torch_available():
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
Qwen2ForCausalLM,
|
Qwen2ForCausalLM,
|
||||||
Qwen2ForSequenceClassification,
|
Qwen2ForSequenceClassification,
|
||||||
|
Qwen2ForTokenClassification,
|
||||||
Qwen2Model,
|
Qwen2Model,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -299,12 +300,17 @@ class Qwen2ModelTester:
|
|||||||
@require_torch
|
@require_torch
|
||||||
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen2
|
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen2
|
||||||
class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (Qwen2Model, Qwen2ForCausalLM, Qwen2ForSequenceClassification) if is_torch_available() else ()
|
all_model_classes = (
|
||||||
|
(Qwen2Model, Qwen2ForCausalLM, Qwen2ForSequenceClassification, Qwen2ForTokenClassification)
|
||||||
|
if is_torch_available()
|
||||||
|
else ()
|
||||||
|
)
|
||||||
all_generative_model_classes = (Qwen2ForCausalLM,) if is_torch_available() else ()
|
all_generative_model_classes = (Qwen2ForCausalLM,) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{
|
{
|
||||||
"feature-extraction": Qwen2Model,
|
"feature-extraction": Qwen2Model,
|
||||||
"text-classification": Qwen2ForSequenceClassification,
|
"text-classification": Qwen2ForSequenceClassification,
|
||||||
|
"token-classification": Qwen2ForTokenClassification,
|
||||||
"text-generation": Qwen2ForCausalLM,
|
"text-generation": Qwen2ForCausalLM,
|
||||||
"zero-shot": Qwen2ForSequenceClassification,
|
"zero-shot": Qwen2ForSequenceClassification,
|
||||||
}
|
}
|
||||||
@@ -387,6 +393,22 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||||
|
|
||||||
|
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Qwen2,llama->Qwen2
|
||||||
|
def test_Qwen2_token_classification_model(self):
    """Run the token-classification head once and check it yields one logit vector per token."""
    tester = self.model_tester
    config, inputs = tester.prepare_config_and_inputs_for_common()
    config.num_labels = 3

    token_ids = inputs["input_ids"]
    # The attention mask is False wherever the input id equals 1.
    mask = token_ids.ne(1).to(torch_device)
    labels = ids_tensor([tester.batch_size, tester.seq_length], config.num_labels)

    classifier = Qwen2ForTokenClassification(config=config)
    classifier.to(torch_device)
    classifier.eval()

    outputs = classifier(token_ids, attention_mask=mask, labels=labels)
    expected_shape = (tester.batch_size, tester.seq_length, tester.num_labels)
    self.assertEqual(outputs.logits.shape, expected_shape)
|
||||||
|
|
||||||
@unittest.skip("Qwen2 buffers include complex numbers, which breaks this test")
|
@unittest.skip("Qwen2 buffers include complex numbers, which breaks this test")
|
||||||
def test_save_load_fast_init_from_base(self):
|
def test_save_load_fast_init_from_base(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ if is_torch_available():
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
Qwen2MoeForCausalLM,
|
Qwen2MoeForCausalLM,
|
||||||
Qwen2MoeForSequenceClassification,
|
Qwen2MoeForSequenceClassification,
|
||||||
|
Qwen2MoeForTokenClassification,
|
||||||
Qwen2MoeModel,
|
Qwen2MoeModel,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -327,13 +328,16 @@ class Qwen2MoeModelTester:
|
|||||||
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen2Moe
|
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen2Moe
|
||||||
class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(Qwen2MoeModel, Qwen2MoeForCausalLM, Qwen2MoeForSequenceClassification) if is_torch_available() else ()
|
(Qwen2MoeModel, Qwen2MoeForCausalLM, Qwen2MoeForSequenceClassification, Qwen2MoeForTokenClassification)
|
||||||
|
if is_torch_available()
|
||||||
|
else ()
|
||||||
)
|
)
|
||||||
all_generative_model_classes = (Qwen2MoeForCausalLM,) if is_torch_available() else ()
|
all_generative_model_classes = (Qwen2MoeForCausalLM,) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{
|
{
|
||||||
"feature-extraction": Qwen2MoeModel,
|
"feature-extraction": Qwen2MoeModel,
|
||||||
"text-classification": Qwen2MoeForSequenceClassification,
|
"text-classification": Qwen2MoeForSequenceClassification,
|
||||||
|
"token-classification": Qwen2MoeForTokenClassification,
|
||||||
"text-generation": Qwen2MoeForCausalLM,
|
"text-generation": Qwen2MoeForCausalLM,
|
||||||
"zero-shot": Qwen2MoeForSequenceClassification,
|
"zero-shot": Qwen2MoeForSequenceClassification,
|
||||||
}
|
}
|
||||||
@@ -414,6 +418,22 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
|||||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||||
|
|
||||||
|
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Qwen2Moe,llama->Qwen2Moe
|
||||||
|
def test_Qwen2Moe_token_classification_model(self):
|
||||||
|
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
config.num_labels = 3
|
||||||
|
input_ids = input_dict["input_ids"]
|
||||||
|
attention_mask = input_ids.ne(1).to(torch_device)
|
||||||
|
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
|
||||||
|
model = Qwen2MoeForTokenClassification(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
|
||||||
|
self.assertEqual(
|
||||||
|
result.logits.shape,
|
||||||
|
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||||
|
)
|
||||||
|
|
||||||
@unittest.skip("Qwen2Moe buffers include complex numbers, which breaks this test")
|
@unittest.skip("Qwen2Moe buffers include complex numbers, which breaks this test")
|
||||||
def test_save_load_fast_init_from_base(self):
|
def test_save_load_fast_init_from_base(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ if is_torch_available():
|
|||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
StableLmForCausalLM,
|
StableLmForCausalLM,
|
||||||
StableLmForSequenceClassification,
|
StableLmForSequenceClassification,
|
||||||
|
StableLmForTokenClassification,
|
||||||
StableLmModel,
|
StableLmModel,
|
||||||
)
|
)
|
||||||
from transformers.models.stablelm.modeling_stablelm import (
|
from transformers.models.stablelm.modeling_stablelm import (
|
||||||
@@ -287,12 +288,15 @@ class StableLmModelTester:
|
|||||||
# Copied from transformers.tests.persimmon.test_modeling_persimmon.PersimmonModelTest with Persimmon -> StableLm
|
# Copied from transformers.tests.persimmon.test_modeling_persimmon.PersimmonModelTest with Persimmon -> StableLm
|
||||||
class StableLmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class StableLmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(StableLmModel, StableLmForCausalLM, StableLmForSequenceClassification) if is_torch_available() else ()
|
(StableLmModel, StableLmForCausalLM, StableLmForSequenceClassification, StableLmForTokenClassification)
|
||||||
|
if is_torch_available()
|
||||||
|
else ()
|
||||||
)
|
)
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{
|
{
|
||||||
"feature-extraction": StableLmModel,
|
"feature-extraction": StableLmModel,
|
||||||
"text-classification": StableLmForSequenceClassification,
|
"text-classification": StableLmForSequenceClassification,
|
||||||
|
"token-classification": StableLmForTokenClassification,
|
||||||
# TODO (ydshieh): check why these two fail. Fix them or skip them in a better way.
|
# TODO (ydshieh): check why these two fail. Fix them or skip them in a better way.
|
||||||
# "text-generation": StableLmForCausalLM,
|
# "text-generation": StableLmForCausalLM,
|
||||||
# "zero-shot": StableLmForSequenceClassification,
|
# "zero-shot": StableLmForSequenceClassification,
|
||||||
@@ -356,6 +360,22 @@ class StableLmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
|||||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||||
|
|
||||||
|
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->StableLm,llama->stablelm
|
||||||
|
def test_stablelm_token_classification_model(self):
|
||||||
|
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
config.num_labels = 3
|
||||||
|
input_ids = input_dict["input_ids"]
|
||||||
|
attention_mask = input_ids.ne(1).to(torch_device)
|
||||||
|
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
|
||||||
|
model = StableLmForTokenClassification(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
|
||||||
|
self.assertEqual(
|
||||||
|
result.logits.shape,
|
||||||
|
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||||
|
)
|
||||||
|
|
||||||
@parameterized.expand([("linear",), ("dynamic",)])
|
@parameterized.expand([("linear",), ("dynamic",)])
|
||||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->StableLm
|
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->StableLm
|
||||||
def test_model_rope_scaling_from_config(self, scaling_type):
|
def test_model_rope_scaling_from_config(self, scaling_type):
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ if is_torch_available():
|
|||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
Starcoder2ForCausalLM,
|
Starcoder2ForCausalLM,
|
||||||
Starcoder2ForSequenceClassification,
|
Starcoder2ForSequenceClassification,
|
||||||
|
Starcoder2ForTokenClassification,
|
||||||
Starcoder2Model,
|
Starcoder2Model,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -290,13 +291,16 @@ class Starcoder2ModelTester:
|
|||||||
# Copied from transformers.tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Starcoder2
|
# Copied from transformers.tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Starcoder2
|
||||||
class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(Starcoder2Model, Starcoder2ForCausalLM, Starcoder2ForSequenceClassification) if is_torch_available() else ()
|
(Starcoder2Model, Starcoder2ForCausalLM, Starcoder2ForSequenceClassification, Starcoder2ForTokenClassification)
|
||||||
|
if is_torch_available()
|
||||||
|
else ()
|
||||||
)
|
)
|
||||||
all_generative_model_classes = (Starcoder2ForCausalLM,) if is_torch_available() else ()
|
all_generative_model_classes = (Starcoder2ForCausalLM,) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{
|
{
|
||||||
"feature-extraction": Starcoder2Model,
|
"feature-extraction": Starcoder2Model,
|
||||||
"text-classification": Starcoder2ForSequenceClassification,
|
"text-classification": Starcoder2ForSequenceClassification,
|
||||||
|
"token-classification": Starcoder2ForTokenClassification,
|
||||||
"text-generation": Starcoder2ForCausalLM,
|
"text-generation": Starcoder2ForCausalLM,
|
||||||
"zero-shot": Starcoder2ForSequenceClassification,
|
"zero-shot": Starcoder2ForSequenceClassification,
|
||||||
}
|
}
|
||||||
@@ -370,6 +374,22 @@ class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
|||||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||||
|
|
||||||
|
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Starcoder2,llama->Starcoder2
|
||||||
|
def test_Starcoder2_token_classification_model(self):
|
||||||
|
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
config.num_labels = 3
|
||||||
|
input_ids = input_dict["input_ids"]
|
||||||
|
attention_mask = input_ids.ne(1).to(torch_device)
|
||||||
|
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
|
||||||
|
model = Starcoder2ForTokenClassification(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
|
||||||
|
self.assertEqual(
|
||||||
|
result.logits.shape,
|
||||||
|
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||||
|
)
|
||||||
|
|
||||||
@unittest.skip("Starcoder2 buffers include complex numbers, which breaks this test")
|
@unittest.skip("Starcoder2 buffers include complex numbers, which breaks this test")
|
||||||
def test_save_load_fast_init_from_base(self):
|
def test_save_load_fast_init_from_base(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
Reference in New Issue
Block a user