Fix the tests for Electra (#6284)
* Fix the tests for Electra * Apply style
This commit is contained in:
@@ -857,7 +857,7 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel):
|
||||
super().__init__(config)
|
||||
|
||||
self.electra = ElectraModel(config)
|
||||
self.summary = SequenceSummary(config)
|
||||
self.sequence_summary = SequenceSummary(config)
|
||||
self.classifier = nn.Linear(config.hidden_size, 1)
|
||||
|
||||
self.init_weights()
|
||||
@@ -915,7 +915,7 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel):
|
||||
|
||||
sequence_output = discriminator_hidden_states[0]
|
||||
|
||||
pooled_output = self.summary(sequence_output)
|
||||
pooled_output = self.sequence_summary(sequence_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
reshaped_logits = logits.view(-1, num_choices)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user