GPTNeoForQuestionAnswering (#23057)
* first draft - gives index error in question_answering.py * maturing * no labels * pipeline should know about QA * fixing checks * formatting * fixed docstring * initial commit * formatting * adding the class to many places * towards less unhappy checks * nearly there * Update src/transformers/models/gpt_neo/modeling_gpt_neo.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * avoid error * moving to device of star/end_logits --------- Co-authored-by: Prof. Peter Schneider-Kamp <jps@ordbogen.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -34,6 +34,7 @@ if is_torch_available():
|
||||
GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
GPT2Tokenizer,
|
||||
GPTNeoForCausalLM,
|
||||
GPTNeoForQuestionAnswering,
|
||||
GPTNeoForSequenceClassification,
|
||||
GPTNeoForTokenClassification,
|
||||
GPTNeoModel,
|
||||
@@ -325,6 +326,17 @@ class GPTNeoModelTester:
|
||||
self.parent.assertEqual(result.loss.shape, ())
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_gpt_neo_for_question_answering(
|
||||
self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = GPTNeoForQuestionAnswering(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
|
||||
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
|
||||
|
||||
def create_and_check_gpt_neo_for_sequence_classification(
|
||||
self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args
|
||||
):
|
||||
@@ -385,7 +397,13 @@ class GPTNeoModelTester:
|
||||
@require_torch
|
||||
class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(GPTNeoModel, GPTNeoForCausalLM, GPTNeoForSequenceClassification, GPTNeoForTokenClassification)
|
||||
(
|
||||
GPTNeoModel,
|
||||
GPTNeoForCausalLM,
|
||||
GPTNeoForQuestionAnswering,
|
||||
GPTNeoForSequenceClassification,
|
||||
GPTNeoForTokenClassification,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
@@ -393,6 +411,7 @@ class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": GPTNeoModel,
|
||||
"question-answering": GPTNeoForQuestionAnswering,
|
||||
"text-classification": GPTNeoForSequenceClassification,
|
||||
"token-classification": GPTNeoForTokenClassification,
|
||||
"text-generation": GPTNeoForCausalLM,
|
||||
@@ -438,6 +457,10 @@ class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_lm_head_model(*config_and_inputs)
|
||||
|
||||
def test_gpt_neo_question_answering_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_gpt_neo_for_question_answering(*config_and_inputs)
|
||||
|
||||
def test_gpt_neo_sequence_classification_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_gpt_neo_for_sequence_classification(*config_and_inputs)
|
||||
|
||||
Reference in New Issue
Block a user