From 4a9e502a3602d21a6005259fb57b8e1c78101410 Mon Sep 17 00:00:00 2001 From: elk-cloner Date: Tue, 1 Dec 2020 12:19:27 +0330 Subject: [PATCH] Ctrl for sequence classification (#8812) * add CTRLForSequenceClassification * pass local test * merge with master * fix modeling test for sequence classification * fix deco * fix assert --- docs/source/model_doc/ctrl.rst | 7 ++ src/transformers/__init__.py | 8 +- src/transformers/models/auto/modeling_auto.py | 3 +- src/transformers/models/ctrl/__init__.py | 8 +- src/transformers/models/ctrl/modeling_ctrl.py | 118 +++++++++++++++++- src/transformers/utils/dummy_pt_objects.py | 9 ++ tests/test_modeling_ctrl.py | 21 +++- 7 files changed, 167 insertions(+), 7 deletions(-) diff --git a/docs/source/model_doc/ctrl.rst b/docs/source/model_doc/ctrl.rst index 86bf6dea78..3da237459c 100644 --- a/docs/source/model_doc/ctrl.rst +++ b/docs/source/model_doc/ctrl.rst @@ -65,6 +65,13 @@ CTRLLMHeadModel :members: forward +CTRLForSequenceClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autoclass:: transformers.CTRLForSequenceClassification + :members: forward + + TFCTRLModel ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 413ea4a45e..2ae5ec3825 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -391,7 +391,13 @@ if is_torch_available(): CamembertForTokenClassification, CamembertModel, ) - from .models.ctrl import CTRL_PRETRAINED_MODEL_ARCHIVE_LIST, CTRLLMHeadModel, CTRLModel, CTRLPreTrainedModel + from .models.ctrl import ( + CTRL_PRETRAINED_MODEL_ARCHIVE_LIST, + CTRLForSequenceClassification, + CTRLLMHeadModel, + CTRLModel, + CTRLPreTrainedModel, + ) from .models.deberta import ( DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, DebertaForSequenceClassification, diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index da89c387cd..f2b30a8372 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -60,7 +60,7 @@ from ..camembert.modeling_camembert import ( CamembertForTokenClassification, CamembertModel, ) -from ..ctrl.modeling_ctrl import CTRLLMHeadModel, CTRLModel +from ..ctrl.modeling_ctrl import CTRLForSequenceClassification, CTRLLMHeadModel, CTRLModel from ..deberta.modeling_deberta import DebertaForSequenceClassification, DebertaModel from ..distilbert.modeling_distilbert import ( DistilBertForMaskedLM, @@ -415,6 +415,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict( (GPT2Config, GPT2ForSequenceClassification), (OpenAIGPTConfig, OpenAIGPTForSequenceClassification), (ReformerConfig, ReformerForSequenceClassification), + (CTRLConfig, CTRLForSequenceClassification), ] ) diff --git a/src/transformers/models/ctrl/__init__.py b/src/transformers/models/ctrl/__init__.py index d32bc87080..3c94882d91 100644 --- a/src/transformers/models/ctrl/__init__.py +++ 
b/src/transformers/models/ctrl/__init__.py @@ -8,7 +8,13 @@ from .tokenization_ctrl import CTRLTokenizer if is_torch_available(): - from .modeling_ctrl import CTRL_PRETRAINED_MODEL_ARCHIVE_LIST, CTRLLMHeadModel, CTRLModel, CTRLPreTrainedModel + from .modeling_ctrl import ( + CTRL_PRETRAINED_MODEL_ARCHIVE_LIST, + CTRLForSequenceClassification, + CTRLLMHeadModel, + CTRLModel, + CTRLPreTrainedModel, + ) if is_tf_available(): from .modeling_tf_ctrl import ( diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py index 8e2862dbe7..f85dd645ad 100644 --- a/src/transformers/models/ctrl/modeling_ctrl.py +++ b/src/transformers/models/ctrl/modeling_ctrl.py @@ -18,10 +18,10 @@ import numpy as np import torch import torch.nn as nn -from torch.nn import CrossEntropyLoss +from torch.nn import CrossEntropyLoss, MSELoss from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutput from ...modeling_utils import Conv1D, PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import logging from .configuration_ctrl import CTRLConfig @@ -571,3 +571,117 @@ class CTRLLMHeadModel(CTRLPreTrainedModel): hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) + + +@add_start_docstrings( + """ + The CTRL Model transformer with a sequence classification head on top (linear layer). + :class:`~transformers.CTRLForSequenceClassification` uses the last token in order to do the classification, as + other causal models (e.g. GPT-2) do. Since it does classification on the last token, it requires to know the + position of the last token. 
If a :obj:`pad_token_id` is defined in the configuration, it finds the last token that + is not a padding token in each row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each + row of the batch. Since it cannot guess the padding tokens when :obj:`inputs_embeds` are passed instead of + :obj:`input_ids`, it does the same (take the last value in each row of the batch). + """, + CTRL_START_DOCSTRING, +) +class CTRLForSequenceClassification(CTRLPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = CTRLModel(config) + self.classifier = nn.Linear(config.n_embd, self.num_labels, bias=False) + + self.init_weights() + + @add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint="ctrl", + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss). + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + logits = self.classifier(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + assert ( + self.config.pad_token_id is not None or batch_size == 1 + ), "Cannot handle batch sizes > 1 if no padding token is defined." + + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1 + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be " + f"unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[range(batch_size), sequence_lengths] + + loss = None + if labels is not None: + if self.num_labels == 1: + # We are doing regression + loss_fct = MSELoss() + loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1)) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (pooled_logits,) + transformer_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=pooled_logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 543ed0bd08..b34e356b00 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -634,6 +634,15 @@ class CamembertModel: CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = None +class CTRLForSequenceClassification: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + class CTRLLMHeadModel: def __init__(self, *args, **kwargs): requires_pytorch(self) diff --git a/tests/test_modeling_ctrl.py b/tests/test_modeling_ctrl.py index 030a7bf9fe..d225462356 100644 --- a/tests/test_modeling_ctrl.py +++ b/tests/test_modeling_ctrl.py @@ -26,7 +26,13 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention if is_torch_available(): import torch - from transformers import CTRL_PRETRAINED_MODEL_ARCHIVE_LIST, CTRLConfig, CTRLLMHeadModel, CTRLModel + from transformers import ( + CTRL_PRETRAINED_MODEL_ARCHIVE_LIST, + CTRLConfig, + CTRLForSequenceClassification, + CTRLLMHeadModel, + CTRLModel, + ) class CTRLModelTester: @@ -57,6 +63,7 @@ class 
CTRLModelTester: self.num_labels = 3 self.num_choices = 4 self.scope = None + self.pad_token_id = self.vocab_size - 1 def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) @@ -94,6 +101,7 @@ class CTRLModelTester: n_ctx=self.max_position_embeddings, # type_vocab_size=self.type_vocab_size, # initializer_range=self.initializer_range, + pad_token_id=self.pad_token_id, ) head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2) @@ -149,11 +157,20 @@ class CTRLModelTester: return config, inputs_dict + def create_and_check_ctrl_for_sequence_classification(self, config, input_ids, head_mask, token_type_ids, *args): + config.num_labels = self.num_labels + model = CTRLForSequenceClassification(config) + model.to(torch_device) + model.eval() + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + result = model(input_ids, token_type_ids=token_type_ids, labels=sequence_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) + @require_torch class CTRLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): - all_model_classes = (CTRLModel, CTRLLMHeadModel) if is_torch_available() else () + all_model_classes = (CTRLModel, CTRLLMHeadModel, CTRLForSequenceClassification) if is_torch_available() else () all_generative_model_classes = (CTRLLMHeadModel,) if is_torch_available() else () test_pruning = True test_torchscript = False