added GPTNeoForTokenClassification (#22908)
* added GPTNeoForTokenClassification * add to top-level init * fixup * test * more fixup * add to gpt_neo.mdx * repo consistency * dummy copy * fix copies * optax >= 0.1.5 assumes jax.Array exists - which it doesn't for jax <= 0.3.6 * merge with main made this superfluous * added classifier_dropout * remove legacy code * removed fmt:on/off removed expected_outputs * doc style fix * classifier_dropout is always in config --------- Co-authored-by: Prof. Peter Schneider-Kamp <jps@ordbogen.com>
This commit is contained in:
@@ -35,6 +35,7 @@ if is_torch_available():
|
||||
GPT2Tokenizer,
|
||||
GPTNeoForCausalLM,
|
||||
GPTNeoForSequenceClassification,
|
||||
GPTNeoForTokenClassification,
|
||||
GPTNeoModel,
|
||||
)
|
||||
|
||||
@@ -334,6 +335,16 @@ class GPTNeoModelTester:
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||
|
||||
def create_and_check_gpt_neo_for_token_classification(
|
||||
self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = GPTNeoForTokenClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
|
||||
|
||||
def create_and_check_forward_and_backwards(
|
||||
self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False
|
||||
):
|
||||
@@ -374,13 +385,16 @@ class GPTNeoModelTester:
|
||||
@require_torch
|
||||
class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(GPTNeoModel, GPTNeoForCausalLM, GPTNeoForSequenceClassification) if is_torch_available() else ()
|
||||
(GPTNeoModel, GPTNeoForCausalLM, GPTNeoForSequenceClassification, GPTNeoForTokenClassification)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (GPTNeoForCausalLM,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": GPTNeoModel,
|
||||
"text-classification": GPTNeoForSequenceClassification,
|
||||
"token-classification": GPTNeoForTokenClassification,
|
||||
"text-generation": GPTNeoForCausalLM,
|
||||
"zero-shot": GPTNeoForSequenceClassification,
|
||||
}
|
||||
@@ -428,6 +442,10 @@ class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_gpt_neo_for_sequence_classification(*config_and_inputs)
|
||||
|
||||
def test_gpt_neo_token_classification_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_gpt_neo_for_token_classification(*config_and_inputs)
|
||||
|
||||
def test_gpt_neo_gradient_checkpointing(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)
|
||||
|
||||
Reference in New Issue
Block a user