Transfoxl seq classification (#8868)
* Transfoxl sequence classification * Transfoxl sequence classification
This commit is contained in:
@@ -75,6 +75,11 @@ TransfoXLLMHeadModel
|
|||||||
.. autoclass:: transformers.TransfoXLLMHeadModel
|
.. autoclass:: transformers.TransfoXLLMHeadModel
|
||||||
:members: forward
|
:members: forward
|
||||||
|
|
||||||
|
TransfoXLForSequenceClassification
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.TransfoXLForSequenceClassification
|
||||||
|
:members: forward
|
||||||
|
|
||||||
TFTransfoXLModel
|
TFTransfoXLModel
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|||||||
@@ -578,6 +578,7 @@ if is_torch_available():
|
|||||||
from .models.transfo_xl import (
|
from .models.transfo_xl import (
|
||||||
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
|
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
AdaptiveEmbedding,
|
AdaptiveEmbedding,
|
||||||
|
TransfoXLForSequenceClassification,
|
||||||
TransfoXLLMHeadModel,
|
TransfoXLLMHeadModel,
|
||||||
TransfoXLModel,
|
TransfoXLModel,
|
||||||
TransfoXLPreTrainedModel,
|
TransfoXLPreTrainedModel,
|
||||||
|
|||||||
@@ -157,7 +157,7 @@ from ..squeezebert.modeling_squeezebert import (
|
|||||||
SqueezeBertModel,
|
SqueezeBertModel,
|
||||||
)
|
)
|
||||||
from ..t5.modeling_t5 import T5ForConditionalGeneration, T5Model
|
from ..t5.modeling_t5 import T5ForConditionalGeneration, T5Model
|
||||||
from ..transfo_xl.modeling_transfo_xl import TransfoXLLMHeadModel, TransfoXLModel
|
from ..transfo_xl.modeling_transfo_xl import TransfoXLForSequenceClassification, TransfoXLLMHeadModel, TransfoXLModel
|
||||||
from ..xlm.modeling_xlm import (
|
from ..xlm.modeling_xlm import (
|
||||||
XLMForMultipleChoice,
|
XLMForMultipleChoice,
|
||||||
XLMForQuestionAnsweringSimple,
|
XLMForQuestionAnsweringSimple,
|
||||||
@@ -416,6 +416,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
|||||||
(OpenAIGPTConfig, OpenAIGPTForSequenceClassification),
|
(OpenAIGPTConfig, OpenAIGPTForSequenceClassification),
|
||||||
(ReformerConfig, ReformerForSequenceClassification),
|
(ReformerConfig, ReformerForSequenceClassification),
|
||||||
(CTRLConfig, CTRLForSequenceClassification),
|
(CTRLConfig, CTRLForSequenceClassification),
|
||||||
|
(TransfoXLConfig, TransfoXLForSequenceClassification),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ if is_torch_available():
|
|||||||
from .modeling_transfo_xl import (
|
from .modeling_transfo_xl import (
|
||||||
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
|
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
AdaptiveEmbedding,
|
AdaptiveEmbedding,
|
||||||
|
TransfoXLForSequenceClassification,
|
||||||
TransfoXLLMHeadModel,
|
TransfoXLLMHeadModel,
|
||||||
TransfoXLModel,
|
TransfoXLModel,
|
||||||
TransfoXLPreTrainedModel,
|
TransfoXLPreTrainedModel,
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from typing import List, Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from torch.nn import CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...file_utils import (
|
from ...file_utils import (
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
@@ -632,6 +633,40 @@ class TransfoXLModelOutput(ModelOutput):
|
|||||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TransfoXLSequenceClassifierOutputWithPast(ModelOutput):
    """
    Base class for outputs of sequence classification models that also return their memory states (``mems``).

    Args:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks). Can be used (see :obj:`mems`
            input) to speed up sequential decoding. The token ids which have their past given to this model should not
            be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # Field order is part of the interface (positional construction and tuple conversion), so it is unchanged.
    # Every field defaults to ``None``; the annotations say so explicitly.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    mems: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TransfoXLLMHeadModelOutput(ModelOutput):
|
class TransfoXLLMHeadModelOutput(ModelOutput):
|
||||||
"""
|
"""
|
||||||
@@ -1101,3 +1136,110 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
|
|||||||
self.crit.cutoffs = new_cutoffs
|
self.crit.cutoffs = new_cutoffs
|
||||||
self.crit.cutoff_ends = [0] + new_cutoffs
|
self.crit.cutoff_ends = [0] + new_cutoffs
|
||||||
self.crit.n_token = new_num_tokens
|
self.crit.n_token = new_num_tokens
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
    """
    The Transformer-XL Model transformer with a sequence classification head on top (linear layer).

    :class:`~transformers.TransfoXLForSequenceClassification` uses the last token in order to do the classification, as
    other causal models (e.g. GPT-1) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    :obj:`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each
    row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot
    guess the padding tokens when :obj:`inputs_embeds` are passed instead of :obj:`input_ids`, it does the same (take
    the last value in each row of the batch).
    """,
    TRANSFO_XL_START_DOCSTRING,
)
class TransfoXLForSequenceClassification(TransfoXLPreTrainedModel):
    # NOTE(review): these patterns do not obviously correspond to any parameter created by this class
    # (``transformer.*`` and ``score.*``); kept unchanged for compatibility -- confirm whether they are needed.
    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]

    def __init__(self, config):
        """Build the Transformer-XL backbone and the bias-free linear classification head."""
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = TransfoXLModel(config)
        # The head is applied to the backbone's hidden states, whose width is ``d_model``. ``d_embed`` only sizes
        # the token embeddings before they are projected to ``d_model``, so sizing the head with it fails as soon
        # as the two values differ. Configurations where they are equal get an identically shaped weight.
        self.score = nn.Linear(config.d_model, self.num_labels, bias=False)
        self.init_weights()

    @add_start_docstrings_to_model_forward(TRANSFO_XL_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="transfo-xl-wt103",
        output_type=TransfoXLSequenceClassifierOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        mems=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            mems=mems,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        # One score vector per position; a single position per row is selected below.
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        pad_token_id = self.config.pad_token_id
        # Without a padding token the real end of each row is unknown, so only a single sequence is accepted.
        # Raised explicitly (instead of ``assert``) so the check survives ``python -O``.
        if pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")

        if pad_token_id is None:
            # No padding information: classify on the last position.
            sequence_lengths = -1
        elif input_ids is not None:
            # Index of the last non-padding token in each row (count of non-pad tokens minus one).
            sequence_lengths = torch.ne(input_ids, pad_token_id).sum(-1) - 1
            sequence_lengths = sequence_lengths.to(logits.device)
        else:
            # Padding cannot be recovered from embeddings, so fall back to the last position and say so.
            sequence_lengths = -1
            logger.warning(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

        batch_indices = torch.arange(batch_size, device=logits.device)
        pooled_logits = logits[batch_indices, sequence_lengths]

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # A single label means regression: compare raw scores against float targets.
                loss_fct = MSELoss()
                loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            # Tuple form: optional loss, then pooled logits, then everything the backbone returned after
            # its first element.
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TransfoXLSequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            mems=transformer_outputs.mems,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
|
||||||
|
|||||||
@@ -1785,6 +1785,15 @@ class AdaptiveEmbedding:
|
|||||||
requires_pytorch(self)
|
requires_pytorch(self)
|
||||||
|
|
||||||
|
|
||||||
|
class TransfoXLForSequenceClassification:
    """Placeholder for the sequence classification model; both entry points only call ``requires_pytorch``."""

    def __init__(self, *args, **kwargs):
        """Accept any arguments and hand this instance to ``requires_pytorch``."""
        # presumably ``requires_pytorch`` reports that PyTorch is unavailable -- defined outside this view
        requires_pytorch(self)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Accept any arguments and hand the class itself to ``requires_pytorch``."""
        requires_pytorch(cls)
|
||||||
|
|
||||||
|
|
||||||
class TransfoXLLMHeadModel:
|
class TransfoXLLMHeadModel:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_pytorch(self)
|
requires_pytorch(self)
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import TransfoXLConfig, TransfoXLLMHeadModel, TransfoXLModel
|
from transformers import TransfoXLConfig, TransfoXLForSequenceClassification, TransfoXLLMHeadModel, TransfoXLModel
|
||||||
from transformers.models.transfo_xl.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST
|
from transformers.models.transfo_xl.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
|
|
||||||
|
|
||||||
@@ -56,6 +56,8 @@ class TransfoXLModelTester:
|
|||||||
self.scope = None
|
self.scope = None
|
||||||
self.seed = 1
|
self.seed = 1
|
||||||
self.eos_token_id = 0
|
self.eos_token_id = 0
|
||||||
|
self.num_labels = 3
|
||||||
|
self.pad_token_id = self.vocab_size - 1
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
@@ -78,6 +80,7 @@ class TransfoXLModelTester:
|
|||||||
div_val=self.div_val,
|
div_val=self.div_val,
|
||||||
n_layer=self.num_hidden_layers,
|
n_layer=self.num_hidden_layers,
|
||||||
eos_token_id=self.eos_token_id,
|
eos_token_id=self.eos_token_id,
|
||||||
|
pad_token_id=self.pad_token_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return (config, input_ids_1, input_ids_2, lm_labels)
|
return (config, input_ids_1, input_ids_2, lm_labels)
|
||||||
@@ -148,6 +151,14 @@ class TransfoXLModelTester:
|
|||||||
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def create_and_check_transfo_xl_for_sequence_classification(self, config, input_ids_1, input_ids_2, lm_labels):
    """
    Build a sequence classification model from ``config`` and check the shape of the logits it returns.

    Only ``input_ids_1`` is fed to the model; ``input_ids_2`` and ``lm_labels`` are accepted but not used here.
    """
    # The number of labels is written onto the config before the model is constructed from it.
    config.num_labels = self.num_labels

    classifier = TransfoXLForSequenceClassification(config)
    classifier.to(torch_device)
    classifier.eval()

    outputs = classifier(input_ids_1)

    expected_shape = (self.batch_size, self.num_labels)
    self.parent.assertEqual(outputs.logits.shape, expected_shape)
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
(config, input_ids_1, input_ids_2, lm_labels) = config_and_inputs
|
(config, input_ids_1, input_ids_2, lm_labels) = config_and_inputs
|
||||||
@@ -157,7 +168,9 @@ class TransfoXLModelTester:
|
|||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (TransfoXLModel, TransfoXLLMHeadModel) if is_torch_available() else ()
|
all_model_classes = (
|
||||||
|
(TransfoXLModel, TransfoXLLMHeadModel, TransfoXLForSequenceClassification) if is_torch_available() else ()
|
||||||
|
)
|
||||||
all_generative_model_classes = (TransfoXLLMHeadModel,) if is_torch_available() else ()
|
all_generative_model_classes = (TransfoXLLMHeadModel,) if is_torch_available() else ()
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
@@ -204,6 +217,10 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC
|
|||||||
output_result = self.model_tester.create_transfo_xl_lm_head(*config_and_inputs)
|
output_result = self.model_tester.create_transfo_xl_lm_head(*config_and_inputs)
|
||||||
self.model_tester.check_transfo_xl_lm_head_output(output_result)
|
self.model_tester.check_transfo_xl_lm_head_output(output_result)
|
||||||
|
|
||||||
|
def test_transfo_xl_sequence_classification_model(self):
    """Run the model tester's sequence classification check on freshly prepared config and inputs."""
    tester = self.model_tester
    prepared = tester.prepare_config_and_inputs()
    tester.create_and_check_transfo_xl_for_sequence_classification(*prepared)
|
||||||
|
|
||||||
def test_retain_grad_hidden_states_attentions(self):
|
def test_retain_grad_hidden_states_attentions(self):
|
||||||
# xlnet cannot keep gradients in attentions or hidden states
|
# xlnet cannot keep gradients in attentions or hidden states
|
||||||
return
|
return
|
||||||
|
|||||||
Reference in New Issue
Block a user