From f6b44e61903964aee07bdad7056a1ad9d0006e6f Mon Sep 17 00:00:00 2001 From: sandip Date: Wed, 2 Dec 2020 20:38:32 +0530 Subject: [PATCH] Transfoxl seq classification (#8868) * Transfoxl sequence classification * Transfoxl sequence classification --- docs/source/model_doc/transformerxl.rst | 5 + src/transformers/__init__.py | 1 + src/transformers/models/auto/modeling_auto.py | 3 +- .../models/transfo_xl/__init__.py | 1 + .../models/transfo_xl/modeling_transfo_xl.py | 142 ++++++++++++++++++ src/transformers/utils/dummy_pt_objects.py | 9 ++ tests/test_modeling_transfo_xl.py | 21 ++- 7 files changed, 179 insertions(+), 3 deletions(-) diff --git a/docs/source/model_doc/transformerxl.rst b/docs/source/model_doc/transformerxl.rst index e12847da61..e2a2d5d472 100644 --- a/docs/source/model_doc/transformerxl.rst +++ b/docs/source/model_doc/transformerxl.rst @@ -75,6 +75,11 @@ TransfoXLLMHeadModel .. autoclass:: transformers.TransfoXLLMHeadModel :members: forward +TransfoXLForSequenceClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.TransfoXLForSequenceClassification + :members: forward TFTransfoXLModel ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index bacfbb64e1..5d341bcfce 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -578,6 +578,7 @@ if is_torch_available(): from .models.transfo_xl import ( TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST, AdaptiveEmbedding, + TransfoXLForSequenceClassification, TransfoXLLMHeadModel, TransfoXLModel, TransfoXLPreTrainedModel, diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index f2b30a8372..0a64f066d4 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -157,7 +157,7 @@ from ..squeezebert.modeling_squeezebert import ( SqueezeBertModel, ) from ..t5.modeling_t5 import T5ForConditionalGeneration, T5Model -from ..transfo_xl.modeling_transfo_xl import TransfoXLLMHeadModel, TransfoXLModel +from ..transfo_xl.modeling_transfo_xl import TransfoXLForSequenceClassification, TransfoXLLMHeadModel, TransfoXLModel from ..xlm.modeling_xlm import ( XLMForMultipleChoice, XLMForQuestionAnsweringSimple, @@ -416,6 +416,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict( (OpenAIGPTConfig, OpenAIGPTForSequenceClassification), (ReformerConfig, ReformerForSequenceClassification), (CTRLConfig, CTRLForSequenceClassification), + (TransfoXLConfig, TransfoXLForSequenceClassification), ] ) diff --git a/src/transformers/models/transfo_xl/__init__.py b/src/transformers/models/transfo_xl/__init__.py index 2dc009b7f6..0d0bd7a56f 100644 --- a/src/transformers/models/transfo_xl/__init__.py +++ b/src/transformers/models/transfo_xl/__init__.py @@ -11,6 +11,7 @@ if is_torch_available(): from .modeling_transfo_xl import ( TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST, AdaptiveEmbedding, + TransfoXLForSequenceClassification, TransfoXLLMHeadModel, TransfoXLModel, TransfoXLPreTrainedModel, diff --git a/src/transformers/models/transfo_xl/modeling_transfo_xl.py b/src/transformers/models/transfo_xl/modeling_transfo_xl.py index 63ab53e07e..3e579566f6 100644 --- a/src/transformers/models/transfo_xl/modeling_transfo_xl.py +++ b/src/transformers/models/transfo_xl/modeling_transfo_xl.py @@ -23,6 +23,7 @@ from typing import List, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F +from torch.nn import CrossEntropyLoss, MSELoss from ...file_utils import ( ModelOutput, @@ -632,6 +633,40 @@ class TransfoXLModelOutput(ModelOutput): attentions: Optional[Tuple[torch.FloatTensor]] = None +@dataclass +class TransfoXLSequenceClassifierOutputWithPast(ModelOutput): + """ + Base class for outputs of sentence classification models. + + Args: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`): + Contains pre-computed hidden-states (key and values in the attention blocks). Can be used (see :obj:`mems` + input) to speed up sequential decoding. The token ids which have their past given to this model should not + be passed as input ids as they have already been computed. + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + mems: List[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + @dataclass class TransfoXLLMHeadModelOutput(ModelOutput): """ @@ -1101,3 +1136,110 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): self.crit.cutoffs = new_cutoffs self.crit.cutoff_ends = [0] + new_cutoffs self.crit.n_token = new_num_tokens + + +@add_start_docstrings( + """ + The Transformer-XL Model transformer with a sequence classification head on top (linear layer). + + :class:`~transformers.TransfoXLForSequenceClassification` uses the last token in order to do the classification, as + other causal models (e.g. GPT-1) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + :obj:`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each + row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot + guess the padding tokens when :obj:`inputs_embeds` are passed instead of :obj:`input_ids`, it does the same (take + the last value in each row of the batch). + """, + TRANSFO_XL_START_DOCSTRING, +) +class TransfoXLForSequenceClassification(TransfoXLPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = TransfoXLModel(config) + self.score = nn.Linear(config.d_embed, self.num_labels, bias=False) + self.init_weights() + + @add_start_docstrings_to_model_forward(TRANSFO_XL_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint="transfo-xl-wt103", + output_type=TransfoXLSequenceClassifierOutputWithPast, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + mems=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + mems=mems, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + assert ( + self.config.pad_token_id is not None or batch_size == 1 + ), "Cannot handle batch sizes > 1 if no padding token is defined." + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1 + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + f"unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[range(batch_size), sequence_lengths] + + loss = None + if labels is not None: + if self.num_labels == 1: + loss_fct = MSELoss() + loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1)) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TransfoXLSequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + mems=transformer_outputs.mems, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index b34e356b00..3dd8acffac 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1785,6 +1785,15 @@ class AdaptiveEmbedding: requires_pytorch(self) +class TransfoXLForSequenceClassification: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + class TransfoXLLMHeadModel: def __init__(self, *args, **kwargs): requires_pytorch(self) diff --git a/tests/test_modeling_transfo_xl.py b/tests/test_modeling_transfo_xl.py index 7f2fa26cce..e8bed3cfbd 100644 --- a/tests/test_modeling_transfo_xl.py +++ b/tests/test_modeling_transfo_xl.py @@ -27,7 +27,7 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor if is_torch_available(): import torch - from transformers import TransfoXLConfig, TransfoXLLMHeadModel, TransfoXLModel + from transformers import TransfoXLConfig, TransfoXLForSequenceClassification, TransfoXLLMHeadModel, TransfoXLModel from transformers.models.transfo_xl.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST @@ -56,6 +56,8 @@ class TransfoXLModelTester: self.scope = None self.seed = 1 self.eos_token_id = 0 + self.num_labels = 3 + self.pad_token_id = self.vocab_size - 1 def prepare_config_and_inputs(self): input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) @@ -78,6 +80,7 @@ class TransfoXLModelTester: div_val=self.div_val, n_layer=self.num_hidden_layers, eos_token_id=self.eos_token_id, + pad_token_id=self.pad_token_id, ) return (config, input_ids_1, input_ids_2, lm_labels) @@ -148,6 +151,14 @@ class TransfoXLModelTester: [(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers, ) + def create_and_check_transfo_xl_for_sequence_classification(self, config, input_ids_1, input_ids_2, lm_labels): + config.num_labels = self.num_labels + model = TransfoXLForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids_1) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() (config, input_ids_1, input_ids_2, lm_labels) = config_and_inputs @@ -157,7 +168,9 @@ class TransfoXLModelTester: @require_torch class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): - all_model_classes = (TransfoXLModel, TransfoXLLMHeadModel) if is_torch_available() else () + all_model_classes = ( + (TransfoXLModel, TransfoXLLMHeadModel, TransfoXLForSequenceClassification) if is_torch_available() else () + ) all_generative_model_classes = (TransfoXLLMHeadModel,) if is_torch_available() else () test_pruning = False test_torchscript = False @@ -204,6 +217,10 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC output_result = self.model_tester.create_transfo_xl_lm_head(*config_and_inputs) self.model_tester.check_transfo_xl_lm_head_output(output_result) + def test_transfo_xl_sequence_classification_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_transfo_xl_for_sequence_classification(*config_and_inputs) + def test_retain_grad_hidden_states_attentions(self): # xlnet cannot keep gradients in attentions or hidden states return