Fix the tests for Electra (#6284)
* Fix the tests for Electra * Apply style
This commit is contained in:
@@ -857,7 +857,7 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel):
|
|||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
self.electra = ElectraModel(config)
|
self.electra = ElectraModel(config)
|
||||||
self.summary = SequenceSummary(config)
|
self.sequence_summary = SequenceSummary(config)
|
||||||
self.classifier = nn.Linear(config.hidden_size, 1)
|
self.classifier = nn.Linear(config.hidden_size, 1)
|
||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
@@ -915,7 +915,7 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel):
|
|||||||
|
|
||||||
sequence_output = discriminator_hidden_states[0]
|
sequence_output = discriminator_hidden_states[0]
|
||||||
|
|
||||||
pooled_output = self.summary(sequence_output)
|
pooled_output = self.sequence_summary(sequence_output)
|
||||||
logits = self.classifier(pooled_output)
|
logits = self.classifier(pooled_output)
|
||||||
reshaped_logits = logits.view(-1, num_choices)
|
reshaped_logits = logits.view(-1, num_choices)
|
||||||
|
|
||||||
|
|||||||
@@ -63,6 +63,7 @@ class TFElectraModelTester:
|
|||||||
self.num_labels = 3
|
self.num_labels = 3
|
||||||
self.num_choices = 4
|
self.num_choices = 4
|
||||||
self.scope = None
|
self.scope = None
|
||||||
|
self.embedding_size = 128
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
@@ -194,7 +195,14 @@ class TFElectraModelTester:
|
|||||||
class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase):
|
class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(TFElectraModel, TFElectraForMaskedLM, TFElectraForPreTraining, TFElectraForTokenClassification,)
|
(
|
||||||
|
TFElectraModel,
|
||||||
|
TFElectraForMaskedLM,
|
||||||
|
TFElectraForPreTraining,
|
||||||
|
TFElectraForTokenClassification,
|
||||||
|
TFElectraForMultipleChoice,
|
||||||
|
TFElectraForSequenceClassification,
|
||||||
|
)
|
||||||
if is_tf_available()
|
if is_tf_available()
|
||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user