Add LlamaForSequenceClassification (#22209)
* Add LlamaForSequenceClassification * Update src/transformers/models/llama/modeling_llama.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update src/transformers/models/llama/modeling_llama.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Add docstring * Add test * Add input embedding getter and setter * Remove dead code --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
@@ -27,7 +27,7 @@ from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attenti
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import LlamaForCausalLM, LlamaModel
|
||||
from transformers import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
|
||||
|
||||
|
||||
class LlamaModelTester:
|
||||
@@ -255,14 +255,7 @@ class LlamaModelTester:
|
||||
|
||||
@require_torch
|
||||
class LlamaModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(
|
||||
LlamaModel,
|
||||
LlamaForCausalLM,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
all_model_classes = (LlamaModel, LlamaForCausalLM, LlamaForSequenceClassification) if is_torch_available() else ()
|
||||
all_generative_model_classes = (LlamaForCausalLM,) if is_torch_available() else ()
|
||||
test_headmasking = False
|
||||
|
||||
@@ -283,6 +276,46 @@ class LlamaModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs[0].position_embedding_type = type
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_llama_sequence_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
|
||||
model = LlamaForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
def test_llama_sequence_classification_model_for_single_label(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
config.problem_type = "single_label_classification"
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
|
||||
model = LlamaForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
def test_llama_sequence_classification_model_for_multi_label(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
config.problem_type = "multi_label_classification"
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
sequence_labels = ids_tensor(
|
||||
[self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
|
||||
).to(torch.float)
|
||||
model = LlamaForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
@unittest.skip("LLaMA does not support head pruning.")
|
||||
def test_head_pruning(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user