Add TokenClassification for Mistral, Mixtral and Qwen2 (#29878)
* Add MistralForTokenClassification * Add tests and docs * Add token classification for Mixtral and Qwen2 * Save llma for token classification draft * Add token classification support for Llama, Gemma, Persimmon, StableLm and StarCoder2 * Formatting * Add token classification support for Qwen2Moe model * Add dropout layer to each ForTokenClassification model * Add copied from in tests * Update src/transformers/models/llama/modeling_llama.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Propagate suggested changes * Style --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
@@ -41,7 +41,13 @@ from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import GemmaForCausalLM, GemmaForSequenceClassification, GemmaModel, GemmaTokenizer
|
||||
from transformers import (
|
||||
GemmaForCausalLM,
|
||||
GemmaForSequenceClassification,
|
||||
GemmaForTokenClassification,
|
||||
GemmaModel,
|
||||
GemmaTokenizer,
|
||||
)
|
||||
|
||||
|
||||
class GemmaModelTester:
|
||||
@@ -284,12 +290,17 @@ class GemmaModelTester:
|
||||
|
||||
@require_torch
|
||||
class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (GemmaModel, GemmaForCausalLM, GemmaForSequenceClassification) if is_torch_available() else ()
|
||||
all_model_classes = (
|
||||
(GemmaModel, GemmaForCausalLM, GemmaForSequenceClassification, GemmaForTokenClassification)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (GemmaForCausalLM,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": GemmaModel,
|
||||
"text-classification": GemmaForSequenceClassification,
|
||||
"token-classification": GemmaForTokenClassification,
|
||||
"text-generation": GemmaForCausalLM,
|
||||
"zero-shot": GemmaForSequenceClassification,
|
||||
}
|
||||
@@ -370,6 +381,22 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Gemma,llama->Gemma
|
||||
def test_Gemma_token_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
|
||||
model = GemmaForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
|
||||
self.assertEqual(
|
||||
result.logits.shape,
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||
)
|
||||
|
||||
@unittest.skip("Gemma buffers include complex numbers, which breaks this test")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
@@ -47,6 +47,7 @@ if is_torch_available():
|
||||
LlamaForCausalLM,
|
||||
LlamaForQuestionAnswering,
|
||||
LlamaForSequenceClassification,
|
||||
LlamaForTokenClassification,
|
||||
LlamaModel,
|
||||
LlamaTokenizer,
|
||||
)
|
||||
@@ -286,7 +287,13 @@ class LlamaModelTester:
|
||||
@require_torch
|
||||
class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(LlamaModel, LlamaForCausalLM, LlamaForSequenceClassification, LlamaForQuestionAnswering)
|
||||
(
|
||||
LlamaModel,
|
||||
LlamaForCausalLM,
|
||||
LlamaForSequenceClassification,
|
||||
LlamaForQuestionAnswering,
|
||||
LlamaForTokenClassification,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
@@ -298,6 +305,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
"text-generation": LlamaForCausalLM,
|
||||
"zero-shot": LlamaForSequenceClassification,
|
||||
"question-answering": LlamaForQuestionAnswering,
|
||||
"token-classification": LlamaForTokenClassification,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
@@ -370,6 +378,21 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
def test_llama_token_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
|
||||
model = LlamaForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
|
||||
self.assertEqual(
|
||||
result.logits.shape,
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||
)
|
||||
|
||||
@unittest.skip("Llama buffers include complex numbers, which breaks this test")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
@@ -46,6 +46,7 @@ if is_torch_available():
|
||||
from transformers import (
|
||||
MistralForCausalLM,
|
||||
MistralForSequenceClassification,
|
||||
MistralForTokenClassification,
|
||||
MistralModel,
|
||||
)
|
||||
|
||||
@@ -288,13 +289,16 @@ class MistralModelTester:
|
||||
@require_torch
|
||||
class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(MistralModel, MistralForCausalLM, MistralForSequenceClassification) if is_torch_available() else ()
|
||||
(MistralModel, MistralForCausalLM, MistralForSequenceClassification, MistralForTokenClassification)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (MistralForCausalLM,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": MistralModel,
|
||||
"text-classification": MistralForSequenceClassification,
|
||||
"token-classification": MistralForTokenClassification,
|
||||
"text-generation": MistralForCausalLM,
|
||||
"zero-shot": MistralForSequenceClassification,
|
||||
}
|
||||
@@ -376,6 +380,22 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Mistral,llama->Mistral
|
||||
def test_Mistral_token_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
|
||||
model = MistralForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
|
||||
self.assertEqual(
|
||||
result.logits.shape,
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||
)
|
||||
|
||||
@unittest.skip("Mistral buffers include complex numbers, which breaks this test")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
@@ -40,7 +40,12 @@ from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import MixtralForCausalLM, MixtralForSequenceClassification, MixtralModel
|
||||
from transformers import (
|
||||
MixtralForCausalLM,
|
||||
MixtralForSequenceClassification,
|
||||
MixtralForTokenClassification,
|
||||
MixtralModel,
|
||||
)
|
||||
|
||||
|
||||
class MixtralModelTester:
|
||||
@@ -287,13 +292,16 @@ class MixtralModelTester:
|
||||
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Mixtral
|
||||
class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(MixtralModel, MixtralForCausalLM, MixtralForSequenceClassification) if is_torch_available() else ()
|
||||
(MixtralModel, MixtralForCausalLM, MixtralForSequenceClassification, MixtralForTokenClassification)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (MixtralForCausalLM,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": MixtralModel,
|
||||
"text-classification": MixtralForSequenceClassification,
|
||||
"token-classification": MixtralForTokenClassification,
|
||||
"text-generation": MixtralForCausalLM,
|
||||
"zero-shot": MixtralForSequenceClassification,
|
||||
}
|
||||
@@ -375,6 +383,22 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Mixtral,llama->Mixtral
|
||||
def test_Mixtral_token_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
|
||||
model = MixtralForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
|
||||
self.assertEqual(
|
||||
result.logits.shape,
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||
)
|
||||
|
||||
@unittest.skip("Mixtral buffers include complex numbers, which breaks this test")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
@@ -44,6 +44,7 @@ if is_torch_available():
|
||||
AutoTokenizer,
|
||||
PersimmonForCausalLM,
|
||||
PersimmonForSequenceClassification,
|
||||
PersimmonForTokenClassification,
|
||||
PersimmonModel,
|
||||
)
|
||||
from transformers.models.persimmon.modeling_persimmon import (
|
||||
@@ -283,12 +284,15 @@ class PersimmonModelTester:
|
||||
@require_torch
|
||||
class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(PersimmonModel, PersimmonForCausalLM, PersimmonForSequenceClassification) if is_torch_available() else ()
|
||||
(PersimmonModel, PersimmonForCausalLM, PersimmonForSequenceClassification, PersimmonForTokenClassification)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": PersimmonModel,
|
||||
"text-classification": PersimmonForSequenceClassification,
|
||||
"token-classification": PersimmonForTokenClassification,
|
||||
# TODO (ydshieh): check why these two fail. Fix them or skip them in a better way.
|
||||
# "text-generation": PersimmonForCausalLM,
|
||||
# "zero-shot": PersimmonForSequenceClassification,
|
||||
@@ -365,6 +369,22 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Persimmon,llama->persimmon
|
||||
def test_persimmon_token_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
|
||||
model = PersimmonForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
|
||||
self.assertEqual(
|
||||
result.logits.shape,
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||
)
|
||||
|
||||
@unittest.skip("Persimmon buffers include complex numbers, which breaks this test")
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_save_load_fast_init_from_base
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
|
||||
@@ -45,6 +45,7 @@ if is_torch_available():
|
||||
from transformers import (
|
||||
Qwen2ForCausalLM,
|
||||
Qwen2ForSequenceClassification,
|
||||
Qwen2ForTokenClassification,
|
||||
Qwen2Model,
|
||||
)
|
||||
|
||||
@@ -299,12 +300,17 @@ class Qwen2ModelTester:
|
||||
@require_torch
|
||||
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen2
|
||||
class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Qwen2Model, Qwen2ForCausalLM, Qwen2ForSequenceClassification) if is_torch_available() else ()
|
||||
all_model_classes = (
|
||||
(Qwen2Model, Qwen2ForCausalLM, Qwen2ForSequenceClassification, Qwen2ForTokenClassification)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (Qwen2ForCausalLM,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": Qwen2Model,
|
||||
"text-classification": Qwen2ForSequenceClassification,
|
||||
"token-classification": Qwen2ForTokenClassification,
|
||||
"text-generation": Qwen2ForCausalLM,
|
||||
"zero-shot": Qwen2ForSequenceClassification,
|
||||
}
|
||||
@@ -387,6 +393,22 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Qwen2,llama->Qwen2
|
||||
def test_Qwen2_token_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
|
||||
model = Qwen2ForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
|
||||
self.assertEqual(
|
||||
result.logits.shape,
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||
)
|
||||
|
||||
@unittest.skip("Qwen2 buffers include complex numbers, which breaks this test")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
@@ -45,6 +45,7 @@ if is_torch_available():
|
||||
from transformers import (
|
||||
Qwen2MoeForCausalLM,
|
||||
Qwen2MoeForSequenceClassification,
|
||||
Qwen2MoeForTokenClassification,
|
||||
Qwen2MoeModel,
|
||||
)
|
||||
|
||||
@@ -327,13 +328,16 @@ class Qwen2MoeModelTester:
|
||||
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen2Moe
|
||||
class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(Qwen2MoeModel, Qwen2MoeForCausalLM, Qwen2MoeForSequenceClassification) if is_torch_available() else ()
|
||||
(Qwen2MoeModel, Qwen2MoeForCausalLM, Qwen2MoeForSequenceClassification, Qwen2MoeForTokenClassification)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (Qwen2MoeForCausalLM,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": Qwen2MoeModel,
|
||||
"text-classification": Qwen2MoeForSequenceClassification,
|
||||
"token-classification": Qwen2MoeForTokenClassification,
|
||||
"text-generation": Qwen2MoeForCausalLM,
|
||||
"zero-shot": Qwen2MoeForSequenceClassification,
|
||||
}
|
||||
@@ -414,6 +418,22 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Qwen2Moe,llama->Qwen2Moe
|
||||
def test_Qwen2Moe_token_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
|
||||
model = Qwen2MoeForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
|
||||
self.assertEqual(
|
||||
result.logits.shape,
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||
)
|
||||
|
||||
@unittest.skip("Qwen2Moe buffers include complex numbers, which breaks this test")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
@@ -43,6 +43,7 @@ if is_torch_available():
|
||||
AutoTokenizer,
|
||||
StableLmForCausalLM,
|
||||
StableLmForSequenceClassification,
|
||||
StableLmForTokenClassification,
|
||||
StableLmModel,
|
||||
)
|
||||
from transformers.models.stablelm.modeling_stablelm import (
|
||||
@@ -287,12 +288,15 @@ class StableLmModelTester:
|
||||
# Copied from transformers.tests.persimmon.test_modeling_persimmon.PersimmonModelTest with Persimmon -> StableLm
|
||||
class StableLmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(StableLmModel, StableLmForCausalLM, StableLmForSequenceClassification) if is_torch_available() else ()
|
||||
(StableLmModel, StableLmForCausalLM, StableLmForSequenceClassification, StableLmForTokenClassification)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": StableLmModel,
|
||||
"text-classification": StableLmForSequenceClassification,
|
||||
"token-classification": StableLmForTokenClassification,
|
||||
# TODO (ydshieh): check why these two fail. Fix them or skip them in a better way.
|
||||
# "text-generation": StableLmForCausalLM,
|
||||
# "zero-shot": StableLmForSequenceClassification,
|
||||
@@ -356,6 +360,22 @@ class StableLmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->StableLm,llama->stablelm
|
||||
def test_stablelm_token_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
|
||||
model = StableLmForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
|
||||
self.assertEqual(
|
||||
result.logits.shape,
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||
)
|
||||
|
||||
@parameterized.expand([("linear",), ("dynamic",)])
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->StableLm
|
||||
def test_model_rope_scaling_from_config(self, scaling_type):
|
||||
|
||||
@@ -43,6 +43,7 @@ if is_torch_available():
|
||||
AutoTokenizer,
|
||||
Starcoder2ForCausalLM,
|
||||
Starcoder2ForSequenceClassification,
|
||||
Starcoder2ForTokenClassification,
|
||||
Starcoder2Model,
|
||||
)
|
||||
|
||||
@@ -290,13 +291,16 @@ class Starcoder2ModelTester:
|
||||
# Copied from transformers.tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Starcoder2
|
||||
class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(Starcoder2Model, Starcoder2ForCausalLM, Starcoder2ForSequenceClassification) if is_torch_available() else ()
|
||||
(Starcoder2Model, Starcoder2ForCausalLM, Starcoder2ForSequenceClassification, Starcoder2ForTokenClassification)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (Starcoder2ForCausalLM,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": Starcoder2Model,
|
||||
"text-classification": Starcoder2ForSequenceClassification,
|
||||
"token-classification": Starcoder2ForTokenClassification,
|
||||
"text-generation": Starcoder2ForCausalLM,
|
||||
"zero-shot": Starcoder2ForSequenceClassification,
|
||||
}
|
||||
@@ -370,6 +374,22 @@ class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Starcoder2,llama->Starcoder2
|
||||
def test_Starcoder2_token_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
|
||||
model = Starcoder2ForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
|
||||
self.assertEqual(
|
||||
result.logits.shape,
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||
)
|
||||
|
||||
@unittest.skip("Starcoder2 buffers include complex numbers, which breaks this test")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user