From 0e36e515154f686e3927eb6269cc0b80d4669ba1 Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Fri, 7 Aug 2020 15:30:57 +0200 Subject: [PATCH] Fix the tests for Electra (#6284) * Fix the tests for Electra * Apply style --- src/transformers/modeling_electra.py | 4 ++-- tests/test_modeling_tf_electra.py | 10 +++++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_electra.py b/src/transformers/modeling_electra.py index f41c230ca7..5d3d709259 100644 --- a/src/transformers/modeling_electra.py +++ b/src/transformers/modeling_electra.py @@ -857,7 +857,7 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel): super().__init__(config) self.electra = ElectraModel(config) - self.summary = SequenceSummary(config) + self.sequence_summary = SequenceSummary(config) self.classifier = nn.Linear(config.hidden_size, 1) self.init_weights() @@ -915,7 +915,7 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel): sequence_output = discriminator_hidden_states[0] - pooled_output = self.summary(sequence_output) + pooled_output = self.sequence_summary(sequence_output) logits = self.classifier(pooled_output) reshaped_logits = logits.view(-1, num_choices) diff --git a/tests/test_modeling_tf_electra.py b/tests/test_modeling_tf_electra.py index e986137567..90bb7277f4 100644 --- a/tests/test_modeling_tf_electra.py +++ b/tests/test_modeling_tf_electra.py @@ -63,6 +63,7 @@ class TFElectraModelTester: self.num_labels = 3 self.num_choices = 4 self.scope = None + self.embedding_size = 128 def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) @@ -194,7 +195,14 @@ class TFElectraModelTester: class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase): all_model_classes = ( - (TFElectraModel, TFElectraForMaskedLM, TFElectraForPreTraining, TFElectraForTokenClassification,) + ( + TFElectraModel, + TFElectraForMaskedLM, + TFElectraForPreTraining, + TFElectraForTokenClassification, + TFElectraForMultipleChoice, + TFElectraForSequenceClassification, + ) if is_tf_available() else () )