From 07bf2dff78dce04314447efe46ccb70480374daf Mon Sep 17 00:00:00 2001 From: Joseph Enguehard Date: Mon, 20 May 2024 09:06:57 +0100 Subject: [PATCH] Add TokenClassification for Mistral, Mixtral and Qwen2 (#29878) * Add MistralForTokenClassification * Add tests and docs * Add token classification for Mixtral and Qwen2 * Save llma for token classification draft * Add token classification support for Llama, Gemma, Persimmon, StableLm and StarCoder2 * Formatting * Add token classification support for Qwen2Moe model * Add dropout layer to each ForTokenClassification model * Add copied from in tests * Update src/transformers/models/llama/modeling_llama.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Propagate suggested changes * Style --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> --- docs/source/en/model_doc/gemma.md | 5 + docs/source/en/model_doc/llama.md | 5 + docs/source/en/model_doc/mistral.md | 5 + docs/source/en/model_doc/mixtral.md | 5 + docs/source/en/model_doc/persimmon.md | 5 + docs/source/en/model_doc/qwen2.md | 5 + docs/source/en/model_doc/qwen2_moe.md | 5 + docs/source/en/model_doc/stablelm.md | 5 + docs/source/en/model_doc/starcoder2.md | 5 + src/transformers/__init__.py | 25 ++++- src/transformers/models/auto/modeling_auto.py | 9 ++ src/transformers/models/gemma/__init__.py | 2 + .../models/gemma/modeling_gemma.py | 92 ++++++++++++++++++- src/transformers/models/llama/__init__.py | 2 + .../models/llama/modeling_llama.py | 85 +++++++++++++++++ src/transformers/models/mistral/__init__.py | 2 + .../models/mistral/modeling_mistral.py | 92 ++++++++++++++++++- src/transformers/models/mixtral/__init__.py | 2 + .../models/mixtral/modeling_mixtral.py | 86 +++++++++++++++++ src/transformers/models/persimmon/__init__.py | 2 + .../models/persimmon/modeling_persimmon.py | 92 ++++++++++++++++++- src/transformers/models/qwen2/__init__.py | 2 + .../models/qwen2/modeling_qwen2.py | 92 
++++++++++++++++++- src/transformers/models/qwen2_moe/__init__.py | 2 + .../models/qwen2_moe/modeling_qwen2_moe.py | 92 ++++++++++++++++++- src/transformers/models/stablelm/__init__.py | 2 + .../models/stablelm/modeling_stablelm.py | 92 ++++++++++++++++++- .../models/starcoder2/__init__.py | 2 + .../models/starcoder2/modeling_starcoder2.py | 92 ++++++++++++++++++- src/transformers/utils/dummy_pt_objects.py | 63 +++++++++++++ tests/models/gemma/test_modeling_gemma.py | 31 ++++++- tests/models/llama/test_modeling_llama.py | 25 ++++- tests/models/mistral/test_modeling_mistral.py | 22 ++++- tests/models/mixtral/test_modeling_mixtral.py | 28 +++++- .../persimmon/test_modeling_persimmon.py | 22 ++++- tests/models/qwen2/test_modeling_qwen2.py | 24 ++++- .../qwen2_moe/test_modeling_qwen2_moe.py | 22 ++++- .../models/stablelm/test_modeling_stablelm.py | 22 ++++- .../starcoder2/test_modeling_starcoder2.py | 22 ++++- 39 files changed, 1174 insertions(+), 19 deletions(-) diff --git a/docs/source/en/model_doc/gemma.md b/docs/source/en/model_doc/gemma.md index f55995b6d8..abd077af8d 100644 --- a/docs/source/en/model_doc/gemma.md +++ b/docs/source/en/model_doc/gemma.md @@ -60,6 +60,11 @@ This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ), [ [[autodoc]] GemmaForSequenceClassification - forward +## GemmaForTokenClassification + +[[autodoc]] GemmaForTokenClassification + - forward + ## FlaxGemmaModel [[autodoc]] FlaxGemmaModel diff --git a/docs/source/en/model_doc/llama.md b/docs/source/en/model_doc/llama.md index 915d5ecc70..2f0eb63da0 100644 --- a/docs/source/en/model_doc/llama.md +++ b/docs/source/en/model_doc/llama.md @@ -121,6 +121,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h [[autodoc]] LlamaForQuestionAnswering - forward +## LlamaForTokenClassification + +[[autodoc]] LlamaForTokenClassification + - forward + ## FlaxLlamaModel [[autodoc]] FlaxLlamaModel diff --git a/docs/source/en/model_doc/mistral.md 
b/docs/source/en/model_doc/mistral.md index 0ab2142061..d4bc761060 100644 --- a/docs/source/en/model_doc/mistral.md +++ b/docs/source/en/model_doc/mistral.md @@ -203,6 +203,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h [[autodoc]] MistralForSequenceClassification - forward +## MistralForTokenClassification + +[[autodoc]] MistralForTokenClassification + - forward + ## FlaxMistralModel [[autodoc]] FlaxMistralModel diff --git a/docs/source/en/model_doc/mixtral.md b/docs/source/en/model_doc/mixtral.md index 942b040c3f..b93acdec58 100644 --- a/docs/source/en/model_doc/mixtral.md +++ b/docs/source/en/model_doc/mixtral.md @@ -204,3 +204,8 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h [[autodoc]] MixtralForSequenceClassification - forward + +## MixtralForTokenClassification + +[[autodoc]] MixtralForTokenClassification + - forward diff --git a/docs/source/en/model_doc/persimmon.md b/docs/source/en/model_doc/persimmon.md index fe9e66a0b7..7a105ac554 100644 --- a/docs/source/en/model_doc/persimmon.md +++ b/docs/source/en/model_doc/persimmon.md @@ -96,3 +96,8 @@ The `LlamaTokenizer` is used as it is a standard wrapper around sentencepiece. 
T [[autodoc]] PersimmonForSequenceClassification - forward + +## PersimmonForTokenClassification + +[[autodoc]] PersimmonForTokenClassification + - forward diff --git a/docs/source/en/model_doc/qwen2.md b/docs/source/en/model_doc/qwen2.md index 5f9e5dba22..ac0e25e02c 100644 --- a/docs/source/en/model_doc/qwen2.md +++ b/docs/source/en/model_doc/qwen2.md @@ -80,3 +80,8 @@ In the following, we demonstrate how to use `Qwen2-7B-Chat-beta` for the inferen [[autodoc]] Qwen2ForSequenceClassification - forward + +## Qwen2ForTokenClassification + +[[autodoc]] Qwen2ForTokenClassification + - forward diff --git a/docs/source/en/model_doc/qwen2_moe.md b/docs/source/en/model_doc/qwen2_moe.md index 8a546c4016..9c6dc80beb 100644 --- a/docs/source/en/model_doc/qwen2_moe.md +++ b/docs/source/en/model_doc/qwen2_moe.md @@ -75,3 +75,8 @@ In the following, we demonstrate how to use `Qwen1.5-MoE-A2.7B-Chat` for the inf [[autodoc]] Qwen2MoeForSequenceClassification - forward + +## Qwen2MoeForTokenClassification + +[[autodoc]] Qwen2MoeForTokenClassification + - forward diff --git a/docs/source/en/model_doc/stablelm.md b/docs/source/en/model_doc/stablelm.md index 6a50995ca0..09c0e5855c 100644 --- a/docs/source/en/model_doc/stablelm.md +++ b/docs/source/en/model_doc/stablelm.md @@ -104,3 +104,8 @@ Now, to run the model with Flash Attention 2, refer to the snippet below: [[autodoc]] StableLmForSequenceClassification - forward + +## StableLmForTokenClassification + +[[autodoc]] StableLmForTokenClassification + - forward diff --git a/docs/source/en/model_doc/starcoder2.md b/docs/source/en/model_doc/starcoder2.md index 9e2e547b8c..1d107b3855 100644 --- a/docs/source/en/model_doc/starcoder2.md +++ b/docs/source/en/model_doc/starcoder2.md @@ -66,3 +66,8 @@ These ready-to-use checkpoints can be downloaded and used via the HuggingFace Hu [[autodoc]] Starcoder2ForSequenceClassification - forward + +## Starcoder2ForTokenClassification + +[[autodoc]] Starcoder2ForTokenClassification + - forward diff 
--git a/src/transformers/__init__.py b/src/transformers/__init__.py index 18faabc807..4255e30379 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -2031,6 +2031,7 @@ else: [ "GemmaForCausalLM", "GemmaForSequenceClassification", + "GemmaForTokenClassification", "GemmaModel", "GemmaPreTrainedModel", ] @@ -2288,6 +2289,7 @@ else: "LlamaForCausalLM", "LlamaForQuestionAnswering", "LlamaForSequenceClassification", + "LlamaForTokenClassification", "LlamaModel", "LlamaPreTrainedModel", ] @@ -2435,12 +2437,19 @@ else: [ "MistralForCausalLM", "MistralForSequenceClassification", + "MistralForTokenClassification", "MistralModel", "MistralPreTrainedModel", ] ) _import_structure["models.mixtral"].extend( - ["MixtralForCausalLM", "MixtralForSequenceClassification", "MixtralModel", "MixtralPreTrainedModel"] + [ + "MixtralForCausalLM", + "MixtralForSequenceClassification", + "MixtralForTokenClassification", + "MixtralModel", + "MixtralPreTrainedModel", + ] ) _import_structure["models.mobilebert"].extend( [ @@ -2714,6 +2723,7 @@ else: [ "PersimmonForCausalLM", "PersimmonForSequenceClassification", + "PersimmonForTokenClassification", "PersimmonModel", "PersimmonPreTrainedModel", ] @@ -2810,6 +2820,7 @@ else: [ "Qwen2ForCausalLM", "Qwen2ForSequenceClassification", + "Qwen2ForTokenClassification", "Qwen2Model", "Qwen2PreTrainedModel", ] @@ -2818,6 +2829,7 @@ else: [ "Qwen2MoeForCausalLM", "Qwen2MoeForSequenceClassification", + "Qwen2MoeForTokenClassification", "Qwen2MoeModel", "Qwen2MoePreTrainedModel", ] @@ -3066,6 +3078,7 @@ else: [ "StableLmForCausalLM", "StableLmForSequenceClassification", + "StableLmForTokenClassification", "StableLmModel", "StableLmPreTrainedModel", ] @@ -3074,6 +3087,7 @@ else: [ "Starcoder2ForCausalLM", "Starcoder2ForSequenceClassification", + "Starcoder2ForTokenClassification", "Starcoder2Model", "Starcoder2PreTrainedModel", ] @@ -6489,6 +6503,7 @@ if TYPE_CHECKING: from .models.gemma import ( GemmaForCausalLM, 
GemmaForSequenceClassification, + GemmaForTokenClassification, GemmaModel, GemmaPreTrainedModel, ) @@ -6686,6 +6701,7 @@ if TYPE_CHECKING: LlamaForCausalLM, LlamaForQuestionAnswering, LlamaForSequenceClassification, + LlamaForTokenClassification, LlamaModel, LlamaPreTrainedModel, ) @@ -6801,12 +6817,14 @@ if TYPE_CHECKING: from .models.mistral import ( MistralForCausalLM, MistralForSequenceClassification, + MistralForTokenClassification, MistralModel, MistralPreTrainedModel, ) from .models.mixtral import ( MixtralForCausalLM, MixtralForSequenceClassification, + MixtralForTokenClassification, MixtralModel, MixtralPreTrainedModel, ) @@ -7025,6 +7043,7 @@ if TYPE_CHECKING: from .models.persimmon import ( PersimmonForCausalLM, PersimmonForSequenceClassification, + PersimmonForTokenClassification, PersimmonModel, PersimmonPreTrainedModel, ) @@ -7099,12 +7118,14 @@ if TYPE_CHECKING: from .models.qwen2 import ( Qwen2ForCausalLM, Qwen2ForSequenceClassification, + Qwen2ForTokenClassification, Qwen2Model, Qwen2PreTrainedModel, ) from .models.qwen2_moe import ( Qwen2MoeForCausalLM, Qwen2MoeForSequenceClassification, + Qwen2MoeForTokenClassification, Qwen2MoeModel, Qwen2MoePreTrainedModel, ) @@ -7306,12 +7327,14 @@ if TYPE_CHECKING: from .models.stablelm import ( StableLmForCausalLM, StableLmForSequenceClassification, + StableLmForTokenClassification, StableLmModel, StableLmPreTrainedModel, ) from .models.starcoder2 import ( Starcoder2ForCausalLM, Starcoder2ForSequenceClassification, + Starcoder2ForTokenClassification, Starcoder2Model, Starcoder2PreTrainedModel, ) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 7825a0217f..c16ab83bdc 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -1038,6 +1038,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ("flaubert", "FlaubertForTokenClassification"), ("fnet", "FNetForTokenClassification"), 
("funnel", "FunnelForTokenClassification"), + ("gemma", "GemmaForTokenClassification"), ("gpt-sw3", "GPT2ForTokenClassification"), ("gpt2", "GPT2ForTokenClassification"), ("gpt_bigcode", "GPTBigCodeForTokenClassification"), @@ -1048,11 +1049,14 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ("layoutlmv2", "LayoutLMv2ForTokenClassification"), ("layoutlmv3", "LayoutLMv3ForTokenClassification"), ("lilt", "LiltForTokenClassification"), + ("llama", "LlamaForTokenClassification"), ("longformer", "LongformerForTokenClassification"), ("luke", "LukeForTokenClassification"), ("markuplm", "MarkupLMForTokenClassification"), ("mega", "MegaForTokenClassification"), ("megatron-bert", "MegatronBertForTokenClassification"), + ("mistral", "MistralForTokenClassification"), + ("mixtral", "MixtralForTokenClassification"), ("mobilebert", "MobileBertForTokenClassification"), ("mpnet", "MPNetForTokenClassification"), ("mpt", "MptForTokenClassification"), @@ -1060,15 +1064,20 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ("mt5", "MT5ForTokenClassification"), ("nezha", "NezhaForTokenClassification"), ("nystromformer", "NystromformerForTokenClassification"), + ("persimmon", "PersimmonForTokenClassification"), ("phi", "PhiForTokenClassification"), ("phi3", "Phi3ForTokenClassification"), ("qdqbert", "QDQBertForTokenClassification"), + ("qwen2", "Qwen2ForTokenClassification"), + ("qwen2_moe", "Qwen2MoeForTokenClassification"), ("rembert", "RemBertForTokenClassification"), ("roberta", "RobertaForTokenClassification"), ("roberta-prelayernorm", "RobertaPreLayerNormForTokenClassification"), ("roc_bert", "RoCBertForTokenClassification"), ("roformer", "RoFormerForTokenClassification"), ("squeezebert", "SqueezeBertForTokenClassification"), + ("stablelm", "StableLmForTokenClassification"), + ("starcoder2", "Starcoder2ForTokenClassification"), ("t5", "T5ForTokenClassification"), ("umt5", "UMT5ForTokenClassification"), ("xlm", "XLMForTokenClassification"), diff --git 
a/src/transformers/models/gemma/__init__.py b/src/transformers/models/gemma/__init__.py index 1c832e9051..1aafae6e88 100644 --- a/src/transformers/models/gemma/__init__.py +++ b/src/transformers/models/gemma/__init__.py @@ -55,6 +55,7 @@ else: "GemmaModel", "GemmaPreTrainedModel", "GemmaForSequenceClassification", + "GemmaForTokenClassification", ] try: @@ -98,6 +99,7 @@ if TYPE_CHECKING: from .modeling_gemma import ( GemmaForCausalLM, GemmaForSequenceClassification, + GemmaForTokenClassification, GemmaModel, GemmaPreTrainedModel, ) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index a3fc4b171d..ec88074ad5 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -30,7 +30,12 @@ from ...modeling_attn_mask_utils import ( AttentionMaskConverter, _prepare_4d_causal_attention_mask, ) -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13 from ...utils import ( @@ -1346,3 +1351,88 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel): hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) + + +@add_start_docstrings( + """ + The Gemma Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. 
+ """, + GEMMA_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Gemma, LLAMA->GEMMA +class GemmaForTokenClassification(GemmaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = GemmaModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., + config.num_labels - 1]`. A Cross-Entropy loss is always computed over the flattened token positions; there is + no regression (Mean-Square) branch, even when `config.num_labels == 1`.
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/llama/__init__.py b/src/transformers/models/llama/__init__.py index 4b8a33118c..3f6461c4c0 100644 --- a/src/transformers/models/llama/__init__.py +++ b/src/transformers/models/llama/__init__.py @@ -55,6 +55,7 @@ else: "LlamaPreTrainedModel", "LlamaForSequenceClassification", "LlamaForQuestionAnswering", + "LlamaForTokenClassification", ] try: @@ -95,6 +96,7 @@ if TYPE_CHECKING: LlamaForCausalLM, LlamaForQuestionAnswering, LlamaForSequenceClassification, + LlamaForTokenClassification, LlamaModel, LlamaPreTrainedModel, ) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index ba02a7fc77..56a83c178c 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -36,6 +36,7 @@ from ...modeling_outputs import ( CausalLMOutputWithPast, QuestionAnsweringModelOutput, SequenceClassifierOutputWithPast, + TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS 
@@ -1516,3 +1517,87 @@ class LlamaForQuestionAnswering(LlamaPreTrainedModel): hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +@add_start_docstrings( + """ + The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + LLAMA_START_DOCSTRING, +) +class LlamaForTokenClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., + config.num_labels - 1]`.
A Cross-Entropy loss is always computed over the flattened token positions; there is + no regression (Mean-Square) branch, even when `config.num_labels == 1`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/mistral/__init__.py b/src/transformers/models/mistral/__init__.py index dc0b85980f..abf1e32a4b 100644 --- a/src/transformers/models/mistral/__init__.py +++ b/src/transformers/models/mistral/__init__.py @@ -32,6 +32,7 @@ else: "MistralModel", "MistralPreTrainedModel", "MistralForSequenceClassification", + "MistralForTokenClassification", ] try: @@ -59,6 +60,7 @@ if TYPE_CHECKING: from .modeling_mistral import ( MistralForCausalLM, MistralForSequenceClassification, + MistralForTokenClassification, MistralModel, MistralPreTrainedModel, ) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index ccfa034ae6..cc5d8f0862 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -31,7 +31,12 @@ from torch.nn import BCEWithLogitsLoss, 
CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, @@ -1366,3 +1371,88 @@ class MistralForSequenceClassification(MistralPreTrainedModel): hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) + + +@add_start_docstrings( + """ + The Mistral Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + MISTRAL_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mistral, LLAMA->MISTRAL +class MistralForTokenClassification(MistralPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = MistralModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + def forward( + self, + 
input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., + config.num_labels - 1]`. A Cross-Entropy loss is always computed over the flattened token positions; there is + no regression (Mean-Square) branch, even when `config.num_labels == 1`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/mixtral/__init__.py b/src/transformers/models/mixtral/__init__.py index 7b8f061dac..b124d41dfb 100644 --- a/src/transformers/models/mixtral/__init__.py +++ 
b/src/transformers/models/mixtral/__init__.py @@ -36,6 +36,7 @@ else: "MixtralModel", "MixtralPreTrainedModel", "MixtralForSequenceClassification", + "MixtralForTokenClassification", ] @@ -51,6 +52,7 @@ if TYPE_CHECKING: from .modeling_mixtral import ( MixtralForCausalLM, MixtralForSequenceClassification, + MixtralForTokenClassification, MixtralModel, MixtralPreTrainedModel, ) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 19e4f15718..7bafe49905 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -38,6 +38,7 @@ from ...modeling_outputs import ( MoeCausalLMOutputWithPast, MoeModelOutputWithPast, SequenceClassifierOutputWithPast, + TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel from ...pytorch_utils import is_torch_greater_or_equal_than_1_13 @@ -1582,3 +1583,88 @@ class MixtralForSequenceClassification(MixtralPreTrainedModel): hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) + + +@add_start_docstrings( + """ + The Mixtral Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. 
+ """, + MIXTRAL_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mixtral, LLAMA->MIXTRAL +class MixtralForTokenClassification(MixtralPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = MixtralModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., + config.num_labels - 1]`. A Cross-Entropy loss is always computed over the flattened token positions; there is + no regression (Mean-Square) branch, even when `config.num_labels == 1`.
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/persimmon/__init__.py b/src/transformers/models/persimmon/__init__.py index 75bc218a29..e1f24ca1b7 100644 --- a/src/transformers/models/persimmon/__init__.py +++ b/src/transformers/models/persimmon/__init__.py @@ -36,6 +36,7 @@ else: "PersimmonModel", "PersimmonPreTrainedModel", "PersimmonForSequenceClassification", + "PersimmonForTokenClassification", ] @@ -51,6 +52,7 @@ if TYPE_CHECKING: from .modeling_persimmon import ( PersimmonForCausalLM, PersimmonForSequenceClassification, + PersimmonForTokenClassification, PersimmonModel, PersimmonPreTrainedModel, ) diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 8d4ad53207..75ba7163ba 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -29,7 +29,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from 
...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) from ...modeling_utils import PreTrainedModel from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_persimmon import PersimmonConfig @@ -1011,3 +1016,88 @@ class PersimmonForSequenceClassification(PersimmonPreTrainedModel): hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) + + +@add_start_docstrings( + """ + The Persimmon Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + PERSIMMON_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Persimmon, LLAMA->PERSIMMON +class PersimmonForTokenClassification(PersimmonPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = PersimmonModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(PERSIMMON_INPUTS_DOCSTRING) + def forward( + self, + input_ids: 
torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/qwen2/__init__.py b/src/transformers/models/qwen2/__init__.py index 3409f28214..35df37e91a 100644 --- a/src/transformers/models/qwen2/__init__.py +++ 
b/src/transformers/models/qwen2/__init__.py @@ -45,6 +45,7 @@ else: "Qwen2Model", "Qwen2PreTrainedModel", "Qwen2ForSequenceClassification", + "Qwen2ForTokenClassification", ] @@ -69,6 +70,7 @@ if TYPE_CHECKING: from .modeling_qwen2 import ( Qwen2ForCausalLM, Qwen2ForSequenceClassification, + Qwen2ForTokenClassification, Qwen2Model, Qwen2PreTrainedModel, ) diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 76b975ce9b..c9b8bd8fde 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -31,7 +31,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, @@ -1375,3 +1380,88 @@ class Qwen2ForSequenceClassification(Qwen2PreTrainedModel): hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) + + +@add_start_docstrings( + """ + The Qwen2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. 
+ """, + QWEN2_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2, LLAMA->QWEN2 +class Qwen2ForTokenClassification(Qwen2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Qwen2Model(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/qwen2_moe/__init__.py b/src/transformers/models/qwen2_moe/__init__.py index fb12383278..e2b73ba2d1 100644 --- a/src/transformers/models/qwen2_moe/__init__.py +++ b/src/transformers/models/qwen2_moe/__init__.py @@ -36,6 +36,7 @@ else: "Qwen2MoeModel", "Qwen2MoePreTrainedModel", "Qwen2MoeForSequenceClassification", + "Qwen2MoeForTokenClassification", ] @@ -51,6 +52,7 @@ if TYPE_CHECKING: from .modeling_qwen2_moe import ( Qwen2MoeForCausalLM, Qwen2MoeForSequenceClassification, + Qwen2MoeForTokenClassification, Qwen2MoeModel, Qwen2MoePreTrainedModel, ) diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 914a2dbdf4..f8d1bf6bb4 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -32,7 +32,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...modeling_attn_mask_utils 
import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa -from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast, SequenceClassifierOutputWithPast +from ...modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, @@ -1571,3 +1576,88 @@ class Qwen2MoeForSequenceClassification(Qwen2MoePreTrainedModel): hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) + + +@add_start_docstrings( + """ + The Qwen2MoE Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + QWEN2MOE_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2Moe, LLAMA->QWEN2MOE +class Qwen2MoeForTokenClassification(Qwen2MoePreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Qwen2MoeModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = 
None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/stablelm/__init__.py b/src/transformers/models/stablelm/__init__.py index 7fc3a6857f..c00c045f7f 100644 --- a/src/transformers/models/stablelm/__init__.py +++ b/src/transformers/models/stablelm/__init__.py @@ -36,6 +36,7 @@ else: "StableLmModel", "StableLmPreTrainedModel", 
"StableLmForSequenceClassification", + "StableLmForTokenClassification", ] @@ -51,6 +52,7 @@ if TYPE_CHECKING: from .modeling_stablelm import ( StableLmForCausalLM, StableLmForSequenceClassification, + StableLmForTokenClassification, StableLmModel, StableLmPreTrainedModel, ) diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index bc133ffb3d..91a6e83a8b 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -30,7 +30,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, @@ -1383,3 +1388,88 @@ class StableLmForSequenceClassification(StableLmPreTrainedModel): hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) + + +@add_start_docstrings( + """ + The StableLm Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. 
+ """, + STABLELM_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->StableLm, LLAMA->STABLELM +class StableLmForTokenClassification(StableLmPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = StableLmModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(STABLELM_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/starcoder2/__init__.py b/src/transformers/models/starcoder2/__init__.py index 1eb195fde1..d9dc2cd1e5 100644 --- a/src/transformers/models/starcoder2/__init__.py +++ b/src/transformers/models/starcoder2/__init__.py @@ -36,6 +36,7 @@ else: "Starcoder2Model", "Starcoder2PreTrainedModel", "Starcoder2ForSequenceClassification", + "Starcoder2ForTokenClassification", ] @@ -51,6 +52,7 @@ if TYPE_CHECKING: from .modeling_starcoder2 import ( Starcoder2ForCausalLM, Starcoder2ForSequenceClassification, + Starcoder2ForTokenClassification, Starcoder2Model, Starcoder2PreTrainedModel, ) diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 24d285e0b3..ae64b23aa6 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -31,7 +31,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache 
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, @@ -1357,3 +1362,88 @@ class Starcoder2ForSequenceClassification(Starcoder2PreTrainedModel): hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) + + +@add_start_docstrings( + """ + The Starcoder2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + STARCODER2_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Starcoder2, LLAMA->STARCODER2 +class Starcoder2ForTokenClassification(Starcoder2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Starcoder2Model(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = 
None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 681f858556..5e00230aed 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -3692,6 +3692,13 @@ class 
GemmaForSequenceClassification(metaclass=DummyObject): requires_backends(self, ["torch"]) +class GemmaForTokenClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class GemmaModel(metaclass=DummyObject): _backends = ["torch"] @@ -4642,6 +4649,13 @@ class LlamaForSequenceClassification(metaclass=DummyObject): requires_backends(self, ["torch"]) +class LlamaForTokenClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class LlamaModel(metaclass=DummyObject): _backends = ["torch"] @@ -5237,6 +5251,13 @@ class MistralForSequenceClassification(metaclass=DummyObject): requires_backends(self, ["torch"]) +class MistralForTokenClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class MistralModel(metaclass=DummyObject): _backends = ["torch"] @@ -5265,6 +5286,13 @@ class MixtralForSequenceClassification(metaclass=DummyObject): requires_backends(self, ["torch"]) +class MixtralForTokenClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class MixtralModel(metaclass=DummyObject): _backends = ["torch"] @@ -6373,6 +6401,13 @@ class PersimmonForSequenceClassification(metaclass=DummyObject): requires_backends(self, ["torch"]) +class PersimmonForTokenClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class PersimmonModel(metaclass=DummyObject): _backends = ["torch"] @@ -6734,6 +6769,13 @@ class Qwen2ForSequenceClassification(metaclass=DummyObject): requires_backends(self, ["torch"]) +class Qwen2ForTokenClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, 
["torch"]) + + class Qwen2Model(metaclass=DummyObject): _backends = ["torch"] @@ -6762,6 +6804,13 @@ class Qwen2MoeForSequenceClassification(metaclass=DummyObject): requires_backends(self, ["torch"]) +class Qwen2MoeForTokenClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class Qwen2MoeModel(metaclass=DummyObject): _backends = ["torch"] @@ -7793,6 +7842,13 @@ class StableLmForSequenceClassification(metaclass=DummyObject): requires_backends(self, ["torch"]) +class StableLmForTokenClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class StableLmModel(metaclass=DummyObject): _backends = ["torch"] @@ -7821,6 +7877,13 @@ class Starcoder2ForSequenceClassification(metaclass=DummyObject): requires_backends(self, ["torch"]) +class Starcoder2ForTokenClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class Starcoder2Model(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py index 7cc2ebe7bd..8281f59993 100644 --- a/tests/models/gemma/test_modeling_gemma.py +++ b/tests/models/gemma/test_modeling_gemma.py @@ -41,7 +41,13 @@ from ...test_pipeline_mixin import PipelineTesterMixin if is_torch_available(): import torch - from transformers import GemmaForCausalLM, GemmaForSequenceClassification, GemmaModel, GemmaTokenizer + from transformers import ( + GemmaForCausalLM, + GemmaForSequenceClassification, + GemmaForTokenClassification, + GemmaModel, + GemmaTokenizer, + ) class GemmaModelTester: @@ -284,12 +290,17 @@ class GemmaModelTester: @require_torch class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): - all_model_classes = (GemmaModel, GemmaForCausalLM, 
GemmaForSequenceClassification) if is_torch_available() else () + all_model_classes = ( + (GemmaModel, GemmaForCausalLM, GemmaForSequenceClassification, GemmaForTokenClassification) + if is_torch_available() + else () + ) all_generative_model_classes = (GemmaForCausalLM,) if is_torch_available() else () pipeline_model_mapping = ( { "feature-extraction": GemmaModel, "text-classification": GemmaForSequenceClassification, + "token-classification": GemmaForTokenClassification, "text-generation": GemmaForCausalLM, "zero-shot": GemmaForSequenceClassification, } @@ -370,6 +381,22 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Gemma,llama->Gemma + def test_Gemma_token_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) + model = GemmaForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=token_labels) + self.assertEqual( + result.logits.shape, + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), + ) + @unittest.skip("Gemma buffers include complex numbers, which breaks this test") def test_save_load_fast_init_from_base(self): pass diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 9fcdcc6d7e..58269d62e0 100644 --- a/tests/models/llama/test_modeling_llama.py +++ 
b/tests/models/llama/test_modeling_llama.py @@ -47,6 +47,7 @@ if is_torch_available(): LlamaForCausalLM, LlamaForQuestionAnswering, LlamaForSequenceClassification, + LlamaForTokenClassification, LlamaModel, LlamaTokenizer, ) @@ -286,7 +287,13 @@ class LlamaModelTester: @require_torch class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = ( - (LlamaModel, LlamaForCausalLM, LlamaForSequenceClassification, LlamaForQuestionAnswering) + ( + LlamaModel, + LlamaForCausalLM, + LlamaForSequenceClassification, + LlamaForQuestionAnswering, + LlamaForTokenClassification, + ) if is_torch_available() else () ) @@ -298,6 +305,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi "text-generation": LlamaForCausalLM, "zero-shot": LlamaForSequenceClassification, "question-answering": LlamaForQuestionAnswering, + "token-classification": LlamaForTokenClassification, } if is_torch_available() else {} @@ -370,6 +378,21 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + def test_llama_token_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) + model = LlamaForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=token_labels) + self.assertEqual( + result.logits.shape, + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), + ) + @unittest.skip("Llama buffers include complex 
numbers, which breaks this test") def test_save_load_fast_init_from_base(self): pass diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index bbc36c050e..10dcb3bf4d 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -46,6 +46,7 @@ if is_torch_available(): from transformers import ( MistralForCausalLM, MistralForSequenceClassification, + MistralForTokenClassification, MistralModel, ) @@ -288,13 +289,16 @@ class MistralModelTester: @require_torch class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = ( - (MistralModel, MistralForCausalLM, MistralForSequenceClassification) if is_torch_available() else () + (MistralModel, MistralForCausalLM, MistralForSequenceClassification, MistralForTokenClassification) + if is_torch_available() + else () ) all_generative_model_classes = (MistralForCausalLM,) if is_torch_available() else () pipeline_model_mapping = ( { "feature-extraction": MistralModel, "text-classification": MistralForSequenceClassification, + "token-classification": MistralForTokenClassification, "text-generation": MistralForCausalLM, "zero-shot": MistralForSequenceClassification, } @@ -376,6 +380,22 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Mistral,llama->Mistral + def test_Mistral_token_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + token_labels = 
ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) + model = MistralForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=token_labels) + self.assertEqual( + result.logits.shape, + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), + ) + @unittest.skip("Mistral buffers include complex numbers, which breaks this test") def test_save_load_fast_init_from_base(self): pass diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py index 0d92595d8c..0972207fce 100644 --- a/tests/models/mixtral/test_modeling_mixtral.py +++ b/tests/models/mixtral/test_modeling_mixtral.py @@ -40,7 +40,12 @@ from ...test_pipeline_mixin import PipelineTesterMixin if is_torch_available(): import torch - from transformers import MixtralForCausalLM, MixtralForSequenceClassification, MixtralModel + from transformers import ( + MixtralForCausalLM, + MixtralForSequenceClassification, + MixtralForTokenClassification, + MixtralModel, + ) class MixtralModelTester: @@ -287,13 +292,16 @@ class MixtralModelTester: # Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Mixtral class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = ( - (MixtralModel, MixtralForCausalLM, MixtralForSequenceClassification) if is_torch_available() else () + (MixtralModel, MixtralForCausalLM, MixtralForSequenceClassification, MixtralForTokenClassification) + if is_torch_available() + else () ) all_generative_model_classes = (MixtralForCausalLM,) if is_torch_available() else () pipeline_model_mapping = ( { "feature-extraction": MixtralModel, "text-classification": MixtralForSequenceClassification, + "token-classification": MixtralForTokenClassification, "text-generation": MixtralForCausalLM, "zero-shot": 
MixtralForSequenceClassification, } @@ -375,6 +383,22 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Mixtral,llama->Mixtral + def test_Mixtral_token_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) + model = MixtralForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=token_labels) + self.assertEqual( + result.logits.shape, + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), + ) + @unittest.skip("Mixtral buffers include complex numbers, which breaks this test") def test_save_load_fast_init_from_base(self): pass diff --git a/tests/models/persimmon/test_modeling_persimmon.py b/tests/models/persimmon/test_modeling_persimmon.py index 86a69d774f..46a650c55a 100644 --- a/tests/models/persimmon/test_modeling_persimmon.py +++ b/tests/models/persimmon/test_modeling_persimmon.py @@ -44,6 +44,7 @@ if is_torch_available(): AutoTokenizer, PersimmonForCausalLM, PersimmonForSequenceClassification, + PersimmonForTokenClassification, PersimmonModel, ) from transformers.models.persimmon.modeling_persimmon import ( @@ -283,12 +284,15 @@ class PersimmonModelTester: @require_torch class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = ( - (PersimmonModel, PersimmonForCausalLM, 
PersimmonForSequenceClassification) if is_torch_available() else () + (PersimmonModel, PersimmonForCausalLM, PersimmonForSequenceClassification, PersimmonForTokenClassification) + if is_torch_available() + else () ) pipeline_model_mapping = ( { "feature-extraction": PersimmonModel, "text-classification": PersimmonForSequenceClassification, + "token-classification": PersimmonForTokenClassification, # TODO (ydshieh): check why these two fail. Fix them or skip them in a better way. # "text-generation": PersimmonForCausalLM, # "zero-shot": PersimmonForSequenceClassification, @@ -365,6 +369,22 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Persimmon,llama->persimmon + def test_persimmon_token_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) + model = PersimmonForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=token_labels) + self.assertEqual( + result.logits.shape, + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), + ) + @unittest.skip("Persimmon buffers include complex numbers, which breaks this test") # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_save_load_fast_init_from_base def test_save_load_fast_init_from_base(self): diff --git a/tests/models/qwen2/test_modeling_qwen2.py 
b/tests/models/qwen2/test_modeling_qwen2.py index f4e88a97f0..54718c4303 100644 --- a/tests/models/qwen2/test_modeling_qwen2.py +++ b/tests/models/qwen2/test_modeling_qwen2.py @@ -45,6 +45,7 @@ if is_torch_available(): from transformers import ( Qwen2ForCausalLM, Qwen2ForSequenceClassification, + Qwen2ForTokenClassification, Qwen2Model, ) @@ -299,12 +300,17 @@ class Qwen2ModelTester: @require_torch # Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen2 class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): - all_model_classes = (Qwen2Model, Qwen2ForCausalLM, Qwen2ForSequenceClassification) if is_torch_available() else () + all_model_classes = ( + (Qwen2Model, Qwen2ForCausalLM, Qwen2ForSequenceClassification, Qwen2ForTokenClassification) + if is_torch_available() + else () + ) all_generative_model_classes = (Qwen2ForCausalLM,) if is_torch_available() else () pipeline_model_mapping = ( { "feature-extraction": Qwen2Model, "text-classification": Qwen2ForSequenceClassification, + "token-classification": Qwen2ForTokenClassification, "text-generation": Qwen2ForCausalLM, "zero-shot": Qwen2ForSequenceClassification, } @@ -387,6 +393,22 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Qwen2,llama->Qwen2 + def test_Qwen2_token_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) + model = 
Qwen2ForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=token_labels) + self.assertEqual( + result.logits.shape, + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), + ) + @unittest.skip("Qwen2 buffers include complex numbers, which breaks this test") def test_save_load_fast_init_from_base(self): pass diff --git a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py index f0818e680d..48c6ccf78a 100644 --- a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py +++ b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py @@ -45,6 +45,7 @@ if is_torch_available(): from transformers import ( Qwen2MoeForCausalLM, Qwen2MoeForSequenceClassification, + Qwen2MoeForTokenClassification, Qwen2MoeModel, ) @@ -327,13 +328,16 @@ class Qwen2MoeModelTester: # Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen2Moe class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = ( - (Qwen2MoeModel, Qwen2MoeForCausalLM, Qwen2MoeForSequenceClassification) if is_torch_available() else () + (Qwen2MoeModel, Qwen2MoeForCausalLM, Qwen2MoeForSequenceClassification, Qwen2MoeForTokenClassification) + if is_torch_available() + else () ) all_generative_model_classes = (Qwen2MoeForCausalLM,) if is_torch_available() else () pipeline_model_mapping = ( { "feature-extraction": Qwen2MoeModel, "text-classification": Qwen2MoeForSequenceClassification, + "token-classification": Qwen2MoeForTokenClassification, "text-generation": Qwen2MoeForCausalLM, "zero-shot": Qwen2MoeForSequenceClassification, } @@ -414,6 +418,22 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, 
(self.model_tester.batch_size, self.model_tester.num_labels)) + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Qwen2Moe,llama->Qwen2Moe + def test_Qwen2Moe_token_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) + model = Qwen2MoeForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=token_labels) + self.assertEqual( + result.logits.shape, + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), + ) + @unittest.skip("Qwen2Moe buffers include complex numbers, which breaks this test") def test_save_load_fast_init_from_base(self): pass diff --git a/tests/models/stablelm/test_modeling_stablelm.py b/tests/models/stablelm/test_modeling_stablelm.py index f5e74ead9b..083f928612 100644 --- a/tests/models/stablelm/test_modeling_stablelm.py +++ b/tests/models/stablelm/test_modeling_stablelm.py @@ -43,6 +43,7 @@ if is_torch_available(): AutoTokenizer, StableLmForCausalLM, StableLmForSequenceClassification, + StableLmForTokenClassification, StableLmModel, ) from transformers.models.stablelm.modeling_stablelm import ( @@ -287,12 +288,15 @@ class StableLmModelTester: # Copied from transformers.tests.persimmon.test_modeling_persimmon.PersimmonModelTest with Persimmon -> StableLm class StableLmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = ( - (StableLmModel, StableLmForCausalLM, StableLmForSequenceClassification) if is_torch_available() else () + (StableLmModel, StableLmForCausalLM, StableLmForSequenceClassification, StableLmForTokenClassification) 
+ if is_torch_available() + else () ) pipeline_model_mapping = ( { "feature-extraction": StableLmModel, "text-classification": StableLmForSequenceClassification, + "token-classification": StableLmForTokenClassification, # TODO (ydshieh): check why these two fail. Fix them or skip them in a better way. # "text-generation": StableLmForCausalLM, # "zero-shot": StableLmForSequenceClassification, @@ -356,6 +360,22 @@ class StableLmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->StableLm,llama->stablelm + def test_stablelm_token_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) + model = StableLmForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=token_labels) + self.assertEqual( + result.logits.shape, + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), + ) + @parameterized.expand([("linear",), ("dynamic",)]) # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->StableLm def test_model_rope_scaling_from_config(self, scaling_type): diff --git a/tests/models/starcoder2/test_modeling_starcoder2.py b/tests/models/starcoder2/test_modeling_starcoder2.py index 95f604d06b..faba4d254b 100644 --- a/tests/models/starcoder2/test_modeling_starcoder2.py +++ 
b/tests/models/starcoder2/test_modeling_starcoder2.py @@ -43,6 +43,7 @@ if is_torch_available(): AutoTokenizer, Starcoder2ForCausalLM, Starcoder2ForSequenceClassification, + Starcoder2ForTokenClassification, Starcoder2Model, ) @@ -290,13 +291,16 @@ class Starcoder2ModelTester: # Copied from transformers.tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Starcoder2 class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = ( - (Starcoder2Model, Starcoder2ForCausalLM, Starcoder2ForSequenceClassification) if is_torch_available() else () + (Starcoder2Model, Starcoder2ForCausalLM, Starcoder2ForSequenceClassification, Starcoder2ForTokenClassification) + if is_torch_available() + else () ) all_generative_model_classes = (Starcoder2ForCausalLM,) if is_torch_available() else () pipeline_model_mapping = ( { "feature-extraction": Starcoder2Model, "text-classification": Starcoder2ForSequenceClassification, + "token-classification": Starcoder2ForTokenClassification, "text-generation": Starcoder2ForCausalLM, "zero-shot": Starcoder2ForSequenceClassification, } @@ -370,6 +374,22 @@ class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Starcoder2,llama->Starcoder2 + def test_Starcoder2_token_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) + model = 
Starcoder2ForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=token_labels) + self.assertEqual( + result.logits.shape, + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), + ) + @unittest.skip("Starcoder2 buffers include complex numbers, which breaks this test") def test_save_load_fast_init_from_base(self): pass