Add TokenClassification for Mistral, Mixtral and Qwen2 (#29878)
* Add MistralForTokenClassification * Add tests and docs * Add token classification for Mixtral and Qwen2 * Save llma for token classification draft * Add token classification support for Llama, Gemma, Persimmon, StableLm and StarCoder2 * Formatting * Add token classification support for Qwen2Moe model * Add dropout layer to each ForTokenClassification model * Add copied from in tests * Update src/transformers/models/llama/modeling_llama.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Propagate suggested changes * Style --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
@@ -45,6 +45,7 @@ if is_torch_available():
|
||||
from transformers import (
|
||||
Qwen2MoeForCausalLM,
|
||||
Qwen2MoeForSequenceClassification,
|
||||
Qwen2MoeForTokenClassification,
|
||||
Qwen2MoeModel,
|
||||
)
|
||||
|
||||
@@ -327,13 +328,16 @@ class Qwen2MoeModelTester:
|
||||
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen2Moe
|
||||
class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(Qwen2MoeModel, Qwen2MoeForCausalLM, Qwen2MoeForSequenceClassification) if is_torch_available() else ()
|
||||
(Qwen2MoeModel, Qwen2MoeForCausalLM, Qwen2MoeForSequenceClassification, Qwen2MoeForTokenClassification)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (Qwen2MoeForCausalLM,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": Qwen2MoeModel,
|
||||
"text-classification": Qwen2MoeForSequenceClassification,
|
||||
"token-classification": Qwen2MoeForTokenClassification,
|
||||
"text-generation": Qwen2MoeForCausalLM,
|
||||
"zero-shot": Qwen2MoeForSequenceClassification,
|
||||
}
|
||||
@@ -414,6 +418,22 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Qwen2Moe,llama->Qwen2Moe
|
||||
def test_Qwen2Moe_token_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
|
||||
model = Qwen2MoeForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
|
||||
self.assertEqual(
|
||||
result.logits.shape,
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||
)
|
||||
|
||||
@unittest.skip("Qwen2Moe buffers include complex numbers, which breaks this test")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user