added GPTNeoXForTokenClassification (#23002)

* initial commit

* added GPTNeoXForTokenClassification

* typo

* doc
fixed extra comma that turned into a tuple

* unifying variable names
fixing forward call

* classifier_dropout is in config

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------

Co-authored-by: Prof. Peter Schneider-Kamp <jps@ordbogen.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
peter-sk
2023-04-27 17:08:26 +02:00
committed by GitHub
parent 1933231a0a
commit 614e191c4d
9 changed files with 130 additions and 5 deletions

View File

@@ -29,7 +29,12 @@ from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import GPTNeoXForCausalLM, GPTNeoXForSequenceClassification, GPTNeoXModel
from transformers import (
GPTNeoXForCausalLM,
GPTNeoXForSequenceClassification,
GPTNeoXForTokenClassification,
GPTNeoXModel,
)
class GPTNeoXModelTester:
@@ -153,6 +158,14 @@ class GPTNeoXModelTester:
result = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_for_token_classification(self, config, input_ids, input_mask, token_labels):
config.num_labels = self.num_labels
model = GPTNeoXForTokenClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def create_and_check_decoder_model_past_large_inputs(self, config, input_ids, input_mask):
config.is_decoder = True
model = GPTNeoXForCausalLM(config=config)
@@ -200,13 +213,16 @@ class GPTNeoXModelTester:
@require_torch
class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(GPTNeoXModel, GPTNeoXForCausalLM, GPTNeoXForSequenceClassification) if is_torch_available() else ()
(GPTNeoXModel, GPTNeoXForCausalLM, GPTNeoXForSequenceClassification, GPTNeoXForTokenClassification)
if is_torch_available()
else ()
)
all_generative_model_classes = (GPTNeoXForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": GPTNeoXModel,
"text-classification": GPTNeoXForSequenceClassification,
"token-classification": GPTNeoXForTokenClassification,
"text-generation": GPTNeoXForCausalLM,
"zero-shot": GPTNeoXForSequenceClassification,
}
@@ -253,6 +269,10 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
def test_model_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
@unittest.skip(reason="Feed forward chunking is not implemented")
def test_feed_forward_chunking(self):
pass