From bd6e3018322766b3a71ae6675552607923c02636 Mon Sep 17 00:00:00 2001 From: Frankie Liuzzi Date: Fri, 22 May 2020 09:48:21 -0400 Subject: [PATCH] added functionality for electra classification head (#4257) * added functionality for electra classification head * unneeded dropout * Test ELECTRA for sequence classification * Style Co-authored-by: Frankie Co-authored-by: Lysandre --- src/transformers/__init__.py | 1 + src/transformers/modeling_auto.py | 2 + src/transformers/modeling_electra.py | 107 +++++++++++++++++++++++++++ tests/test_modeling_electra.py | 30 ++++++++ 4 files changed, 140 insertions(+) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 3c016aa05e..db44d5293f 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -321,6 +321,7 @@ if is_torch_available(): ElectraForMaskedLM, ElectraForTokenClassification, ElectraPreTrainedModel, + ElectraForSequenceClassification, ElectraModel, load_tf_weights_in_electra, ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP, diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py index 9106b4aeb5..8ed43c526b 100644 --- a/src/transformers/modeling_auto.py +++ b/src/transformers/modeling_auto.py @@ -88,6 +88,7 @@ from .modeling_electra import ( ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP, ElectraForMaskedLM, ElectraForPreTraining, + ElectraForSequenceClassification, ElectraForTokenClassification, ElectraModel, ) @@ -251,6 +252,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict( (XLNetConfig, XLNetForSequenceClassification), (FlaubertConfig, FlaubertForSequenceClassification), (XLMConfig, XLMForSequenceClassification), + (ElectraConfig, ElectraForSequenceClassification), ] ) diff --git a/src/transformers/modeling_electra.py b/src/transformers/modeling_electra.py index d8e87bd944..46b6a418d1 100644 --- a/src/transformers/modeling_electra.py +++ b/src/transformers/modeling_electra.py @@ -3,6 +3,7 @@ import os import torch import torch.nn as nn +from 
torch.nn import CrossEntropyLoss, MSELoss from .activations import get_activation from .configuration_electra import ElectraConfig @@ -330,6 +331,112 @@ class ElectraModel(ElectraPreTrainedModel): return hidden_states +class ElectraClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + + def forward(self, features, **kwargs): + x = features[:, 0, :] # take the first token (equiv. to [CLS]) + x = self.dropout(x) + x = self.dense(x) + x = get_activation("gelu")(x) # although BERT uses tanh here, it seems Electra authors used gelu here + x = self.dropout(x) + x = self.out_proj(x) + return x + + +@add_start_docstrings( + """ELECTRA Model transformer with a sequence classification/regression head on top (a linear layer on top of + the pooled output) e.g. for GLUE tasks. """, + ELECTRA_START_DOCSTRING, +) +class ElectraForSequenceClassification(ElectraPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.electra = ElectraModel(config) + self.classifier = ElectraClassificationHead(config) + + self.init_weights() + + @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): + Labels for computing the sequence classification/regression loss. + Indices should be in :obj:`[0, ..., config.num_labels - 1]`. + If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ + Returns: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attention weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
+ + Examples:: + + from transformers import ElectraTokenizer, ElectraForSequenceClassification + import torch + + tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator') + model = ElectraForSequenceClassification.from_pretrained('google/electra-small-discriminator') + + input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 + labels = torch.tensor([1]).unsqueeze(0) # Batch size 1 + outputs = model(input_ids, labels=labels) + + loss, logits = outputs[:2] + + """ + discriminator_hidden_states = self.electra( + input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds + ) + + sequence_output = discriminator_hidden_states[0] + logits = self.classifier(sequence_output) + + outputs = (logits,) + discriminator_hidden_states[2:] # add hidden states and attention if they are here + + if labels is not None: + if self.num_labels == 1: + # We are doing regression + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1), labels.view(-1)) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + outputs = (loss,) + outputs + + return outputs # (loss), logits, (hidden_states), (attentions) + + @add_start_docstrings( """ Electra model with a binary classification head on top as used during pre-training for identifying generated diff --git a/tests/test_modeling_electra.py b/tests/test_modeling_electra.py index 88b5257d8b..3df77a25da 100644 --- a/tests/test_modeling_electra.py +++ b/tests/test_modeling_electra.py @@ -30,6 +30,7 @@ if is_torch_available(): ElectraForMaskedLM, ElectraForTokenClassification, ElectraForPreTraining, + ElectraForSequenceClassification, ) from transformers.modeling_electra import ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP @@ -242,6 +243,31 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase): self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length]) self.check_loss_output(result) + def 
create_and_check_electra_for_sequence_classification( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + fake_token_labels, + ): + config.num_labels = self.num_labels + model = ElectraForSequenceClassification(config) + model.to(torch_device) + model.eval() + loss, logits = model( + input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels + ) + result = { + "loss": loss, + "logits": logits, + } + self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels]) + self.check_loss_output(result) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() ( @@ -280,6 +306,10 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_electra_for_pretraining(*config_and_inputs) + def test_for_sequence_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_electra_for_sequence_classification(*config_and_inputs) + @slow def test_model_from_pretrained(self): for model_name in list(ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: