add GPTNeoXForSequenceClassification (#22671)

* add GPTNeoXForSequenceClassification

* move the labels to logits.device (ref: #22561)

* fix
This commit is contained in:
Sugawara
2023-04-11 00:52:23 +09:00
committed by GitHub
parent f74b40208d
commit 6daa9cb515
8 changed files with 175 additions and 6 deletions

View File

@@ -29,7 +29,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import GPTNeoXForCausalLM, GPTNeoXModel
from transformers import GPTNeoXForCausalLM, GPTNeoXForSequenceClassification, GPTNeoXModel
class GPTNeoXModelTester:
@@ -80,6 +80,7 @@ class GPTNeoXModelTester:
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
self.pad_token_id = vocab_size - 1
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -110,6 +111,7 @@ class GPTNeoXModelTester:
type_vocab_size=self.type_vocab_size,
is_decoder=False,
initializer_range=self.initializer_range,
pad_token_id=self.pad_token_id,
)
def prepare_config_and_inputs_for_decoder(self):
@@ -142,6 +144,15 @@ class GPTNeoXModelTester:
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_for_sequence_classification(self, config, input_ids, input_mask, token_labels):
config.num_labels = self.num_labels
model = GPTNeoXForSequenceClassification(config)
model.to(torch_device)
model.eval()
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
result = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_decoder_model_past_large_inputs(self, config, input_ids, input_mask):
config.is_decoder = True
model = GPTNeoXForCausalLM(config=config)
@@ -188,10 +199,19 @@ class GPTNeoXModelTester:
@require_torch
class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (GPTNeoXModel, GPTNeoXForCausalLM) if is_torch_available() else ()
all_model_classes = (
(GPTNeoXModel, GPTNeoXForCausalLM, GPTNeoXForSequenceClassification) if is_torch_available() else ()
)
all_generative_model_classes = (GPTNeoXForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": GPTNeoXModel, "text-generation": GPTNeoXForCausalLM} if is_torch_available() else {}
{
"feature-extraction": GPTNeoXModel,
"text-classification": GPTNeoXForSequenceClassification,
"text-generation": GPTNeoXForCausalLM,
"zero-shot": GPTNeoXForSequenceClassification,
}
if is_torch_available()
else {}
)
test_pruning = False
test_missing_keys = False
@@ -229,6 +249,10 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
def test_model_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
@unittest.skip(reason="Feed forward chunking is not implemented")
def test_feed_forward_chunking(self):
pass