Add TokenClassification for Mistral, Mixtral and Qwen2 (#29878)

* Add MistralForTokenClassification

* Add tests and docs

* Add token classification for Mixtral and Qwen2

* Save llma for token classification draft

* Add token classification support for Llama, Gemma, Persimmon, StableLm and StarCoder2

* Formatting

* Add token classification support for Qwen2Moe model

* Add dropout layer to each ForTokenClassification model

* Add copied from in tests

* Update src/transformers/models/llama/modeling_llama.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Propagate suggested changes

* Style

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
Joseph Enguehard
2024-05-20 09:06:57 +01:00
committed by GitHub
parent 481a957814
commit 07bf2dff78
39 changed files with 1174 additions and 19 deletions

View File

@@ -40,7 +40,12 @@ from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import MixtralForCausalLM, MixtralForSequenceClassification, MixtralModel
from transformers import (
MixtralForCausalLM,
MixtralForSequenceClassification,
MixtralForTokenClassification,
MixtralModel,
)
class MixtralModelTester:
@@ -287,13 +292,16 @@ class MixtralModelTester:
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Mixtral
class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(MixtralModel, MixtralForCausalLM, MixtralForSequenceClassification) if is_torch_available() else ()
(MixtralModel, MixtralForCausalLM, MixtralForSequenceClassification, MixtralForTokenClassification)
if is_torch_available()
else ()
)
all_generative_model_classes = (MixtralForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": MixtralModel,
"text-classification": MixtralForSequenceClassification,
"token-classification": MixtralForTokenClassification,
"text-generation": MixtralForCausalLM,
"zero-shot": MixtralForSequenceClassification,
}
@@ -375,6 +383,22 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Mixtral,llama->Mixtral
def test_Mixtral_token_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
model = MixtralForTokenClassification(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
self.assertEqual(
result.logits.shape,
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
)
@unittest.skip("Mixtral buffers include complex numbers, which breaks this test")
def test_save_load_fast_init_from_base(self):
pass