GPTNeoXForQuestionAnswering (#23059)

* first draft - gives index error in question_answering.py

* maturing

* no labels

* pipeline should know about QA

* fixing checks

* formatting

* fixed docstring

* initial commit

* formatting

* adding the class to many places

* towards less unhappy checks

* nearly there

* and gpt neox for qa

* use right model

* forgot this one

* base_model_prefix is "gpt_neox" for GPTNeoX* models

* unnecessary stuff

* Update src/transformers/models/gpt_neox/modeling_gpt_neox.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* format

* Update src/transformers/models/gpt_neox/modeling_gpt_neox.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* removed gpt2 stuff

---------

Co-authored-by: Prof. Peter Schneider-Kamp <jps@ordbogen.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
peter-sk
2023-05-04 16:15:15 +02:00
committed by GitHub
parent 510ad0a8b8
commit 83b38fbea8
9 changed files with 142 additions and 4 deletions

View File

@@ -31,6 +31,7 @@ if is_torch_available():
from transformers import (
GPTNeoXForCausalLM,
GPTNeoXForQuestionAnswering,
GPTNeoXForSequenceClassification,
GPTNeoXForTokenClassification,
GPTNeoXModel,
@@ -149,6 +150,15 @@ class GPTNeoXModelTester:
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_for_question_answering(self, config, input_ids, input_mask, token_labels):
config.num_labels = self.num_labels
model = GPTNeoXForQuestionAnswering(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
def create_and_check_for_sequence_classification(self, config, input_ids, input_mask, token_labels):
config.num_labels = self.num_labels
model = GPTNeoXForSequenceClassification(config)
@@ -213,7 +223,13 @@ class GPTNeoXModelTester:
@require_torch
class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(GPTNeoXModel, GPTNeoXForCausalLM, GPTNeoXForSequenceClassification, GPTNeoXForTokenClassification)
(
GPTNeoXModel,
GPTNeoXForCausalLM,
GPTNeoXForQuestionAnswering,
GPTNeoXForSequenceClassification,
GPTNeoXForTokenClassification,
)
if is_torch_available()
else ()
)
@@ -221,6 +237,7 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
pipeline_model_mapping = (
{
"feature-extraction": GPTNeoXModel,
"question-answering": GPTNeoXForQuestionAnswering,
"text-classification": GPTNeoXForSequenceClassification,
"token-classification": GPTNeoXForTokenClassification,
"text-generation": GPTNeoXForCausalLM,
@@ -265,6 +282,10 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
def test_model_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
def test_model_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)