Transfoxl seq classification (#8868)
* Transfoxl sequence classification * Transfoxl sequence classification
This commit is contained in:
@@ -75,6 +75,11 @@ TransfoXLLMHeadModel
|
||||
.. autoclass:: transformers.TransfoXLLMHeadModel
|
||||
:members: forward
|
||||
|
||||
TransfoXLForSequenceClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TransfoXLForSequenceClassification
|
||||
:members: forward
|
||||
|
||||
TFTransfoXLModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@@ -578,6 +578,7 @@ if is_torch_available():
|
||||
from .models.transfo_xl import (
|
||||
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
AdaptiveEmbedding,
|
||||
TransfoXLForSequenceClassification,
|
||||
TransfoXLLMHeadModel,
|
||||
TransfoXLModel,
|
||||
TransfoXLPreTrainedModel,
|
||||
|
||||
@@ -157,7 +157,7 @@ from ..squeezebert.modeling_squeezebert import (
|
||||
SqueezeBertModel,
|
||||
)
|
||||
from ..t5.modeling_t5 import T5ForConditionalGeneration, T5Model
|
||||
from ..transfo_xl.modeling_transfo_xl import TransfoXLLMHeadModel, TransfoXLModel
|
||||
from ..transfo_xl.modeling_transfo_xl import TransfoXLForSequenceClassification, TransfoXLLMHeadModel, TransfoXLModel
|
||||
from ..xlm.modeling_xlm import (
|
||||
XLMForMultipleChoice,
|
||||
XLMForQuestionAnsweringSimple,
|
||||
@@ -416,6 +416,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
(OpenAIGPTConfig, OpenAIGPTForSequenceClassification),
|
||||
(ReformerConfig, ReformerForSequenceClassification),
|
||||
(CTRLConfig, CTRLForSequenceClassification),
|
||||
(TransfoXLConfig, TransfoXLForSequenceClassification),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ if is_torch_available():
|
||||
from .modeling_transfo_xl import (
|
||||
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
AdaptiveEmbedding,
|
||||
TransfoXLForSequenceClassification,
|
||||
TransfoXLLMHeadModel,
|
||||
TransfoXLModel,
|
||||
TransfoXLPreTrainedModel,
|
||||
|
||||
@@ -23,6 +23,7 @@ from typing import List, Optional, Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...file_utils import (
|
||||
ModelOutput,
|
||||
@@ -632,6 +633,40 @@ class TransfoXLModelOutput(ModelOutput):
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransfoXLSequenceClassifierOutputWithPast(ModelOutput):
|
||||
"""
|
||||
Base class for outputs of sentence classification models.
|
||||
|
||||
Args:
|
||||
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
|
||||
Classification (or regression if config.num_labels==1) loss.
|
||||
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
|
||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
||||
Contains pre-computed hidden-states (key and values in the attention blocks). Can be used (see :obj:`mems`
|
||||
input) to speed up sequential decoding. The token ids which have their past given to this model should not
|
||||
be passed as input ids as they have already been computed.
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits: torch.FloatTensor = None
|
||||
mems: List[torch.FloatTensor] = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransfoXLLMHeadModelOutput(ModelOutput):
|
||||
"""
|
||||
@@ -1101,3 +1136,110 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
|
||||
self.crit.cutoffs = new_cutoffs
|
||||
self.crit.cutoff_ends = [0] + new_cutoffs
|
||||
self.crit.n_token = new_num_tokens
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
The Transformer-XL Model transformer with a sequence classification head on top (linear layer).
|
||||
|
||||
:class:`~transformers.TransfoXLForSequenceClassification` uses the last token in order to do the classification, as
|
||||
other causal models (e.g. GPT-1) do.
|
||||
|
||||
Since it does classification on the last token, it requires to know the position of the last token. If a
|
||||
:obj:`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each
|
||||
row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot
|
||||
guess the padding tokens when :obj:`inputs_embeds` are passed instead of :obj:`input_ids`, it does the same (take
|
||||
the last value in each row of the batch).
|
||||
""",
|
||||
TRANSFO_XL_START_DOCSTRING,
|
||||
)
|
||||
class TransfoXLForSequenceClassification(TransfoXLPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.transformer = TransfoXLModel(config)
|
||||
self.score = nn.Linear(config.d_embed, self.num_labels, bias=False)
|
||||
self.init_weights()
|
||||
|
||||
@add_start_docstrings_to_model_forward(TRANSFO_XL_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="transfo-xl-wt103",
|
||||
output_type=TransfoXLSequenceClassifierOutputWithPast,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
mems=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
|
||||
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
||||
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids,
|
||||
mems=mems,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
hidden_states = transformer_outputs[0]
|
||||
logits = self.score(hidden_states)
|
||||
|
||||
if input_ids is not None:
|
||||
batch_size, sequence_length = input_ids.shape[:2]
|
||||
else:
|
||||
batch_size, sequence_length = inputs_embeds.shape[:2]
|
||||
|
||||
assert (
|
||||
self.config.pad_token_id is not None or batch_size == 1
|
||||
), "Cannot handle batch sizes > 1 if no padding token is defined."
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
else:
|
||||
if input_ids is not None:
|
||||
sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[range(batch_size), sequence_lengths]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if self.num_labels == 1:
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
|
||||
else:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (pooled_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TransfoXLSequenceClassifierOutputWithPast(
|
||||
loss=loss,
|
||||
logits=pooled_logits,
|
||||
mems=transformer_outputs.mems,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
@@ -1785,6 +1785,15 @@ class AdaptiveEmbedding:
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class TransfoXLForSequenceClassification:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class TransfoXLLMHeadModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
@@ -27,7 +27,7 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import TransfoXLConfig, TransfoXLLMHeadModel, TransfoXLModel
|
||||
from transformers import TransfoXLConfig, TransfoXLForSequenceClassification, TransfoXLLMHeadModel, TransfoXLModel
|
||||
from transformers.models.transfo_xl.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
|
||||
|
||||
@@ -56,6 +56,8 @@ class TransfoXLModelTester:
|
||||
self.scope = None
|
||||
self.seed = 1
|
||||
self.eos_token_id = 0
|
||||
self.num_labels = 3
|
||||
self.pad_token_id = self.vocab_size - 1
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
@@ -78,6 +80,7 @@ class TransfoXLModelTester:
|
||||
div_val=self.div_val,
|
||||
n_layer=self.num_hidden_layers,
|
||||
eos_token_id=self.eos_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
)
|
||||
|
||||
return (config, input_ids_1, input_ids_2, lm_labels)
|
||||
@@ -148,6 +151,14 @@ class TransfoXLModelTester:
|
||||
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
def create_and_check_transfo_xl_for_sequence_classification(self, config, input_ids_1, input_ids_2, lm_labels):
|
||||
config.num_labels = self.num_labels
|
||||
model = TransfoXLForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids_1)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(config, input_ids_1, input_ids_2, lm_labels) = config_and_inputs
|
||||
@@ -157,7 +168,9 @@ class TransfoXLModelTester:
|
||||
|
||||
@require_torch
|
||||
class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (TransfoXLModel, TransfoXLLMHeadModel) if is_torch_available() else ()
|
||||
all_model_classes = (
|
||||
(TransfoXLModel, TransfoXLLMHeadModel, TransfoXLForSequenceClassification) if is_torch_available() else ()
|
||||
)
|
||||
all_generative_model_classes = (TransfoXLLMHeadModel,) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
test_torchscript = False
|
||||
@@ -204,6 +217,10 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC
|
||||
output_result = self.model_tester.create_transfo_xl_lm_head(*config_and_inputs)
|
||||
self.model_tester.check_transfo_xl_lm_head_output(output_result)
|
||||
|
||||
def test_transfo_xl_sequence_classification_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_transfo_xl_for_sequence_classification(*config_and_inputs)
|
||||
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
# xlnet cannot keep gradients in attentions or hidden states
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user