Add SequenceClassification and MultipleChoice TF models to Electra (#6227)
* Add SequenceClassification and MultipleChoice TF models to Electra * Apply style * Add summary_proj_to_labels to Electra config * Finally mirroring the PT version of these models * Apply style * Fix Electra test
This commit is contained in:
@@ -24,10 +24,14 @@ from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers.modeling_tf_electra import (
|
||||
TFElectraModel,
|
||||
TFElectraForMaskedLM,
|
||||
TFElectraForMultipleChoice,
|
||||
TFElectraForPreTraining,
|
||||
TFElectraForSequenceClassification,
|
||||
TFElectraForTokenClassification,
|
||||
TFElectraForQuestionAnswering,
|
||||
)
|
||||
@@ -138,6 +142,35 @@ class TFElectraModelTester:
|
||||
}
|
||||
self.parent.assertListEqual(list(result["prediction_scores"].shape), [self.batch_size, self.seq_length])
|
||||
|
||||
def create_and_check_electra_for_sequence_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = TFElectraForSequenceClassification(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
(logits,) = model(inputs)
|
||||
result = {
|
||||
"logits": logits.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels])
|
||||
|
||||
def create_and_check_electra_for_multiple_choice(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_choices = self.num_choices
|
||||
model = TFElectraForMultipleChoice(config=config)
|
||||
multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1))
|
||||
multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
|
||||
multiple_choice_token_type_ids = tf.tile(tf.expand_dims(token_type_ids, 1), (1, self.num_choices, 1))
|
||||
inputs = {
|
||||
"input_ids": multiple_choice_inputs_ids,
|
||||
"attention_mask": multiple_choice_input_mask,
|
||||
"token_type_ids": multiple_choice_token_type_ids,
|
||||
}
|
||||
(logits,) = model(inputs)
|
||||
result = {"logits": logits.numpy()}
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
|
||||
|
||||
def create_and_check_electra_for_question_answering(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
@@ -210,6 +243,14 @@ class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_electra_for_question_answering(*config_and_inputs)
|
||||
|
||||
def test_for_sequence_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_electra_for_sequence_classification(*config_and_inputs)
|
||||
|
||||
def test_for_multiple_choice(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_electra_for_multiple_choice(*config_and_inputs)
|
||||
|
||||
def test_for_token_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_electra_for_token_classification(*config_and_inputs)
|
||||
|
||||
Reference in New Issue
Block a user