PyTorch Transformer-XL
This commit is contained in:
@@ -1,6 +1,31 @@
|
|||||||
Transformer XL
|
Transformer XL
|
||||||
----------------------------------------------------
|
----------------------------------------------------
|
||||||
|
|
||||||
|
The Transformer-XL model was proposed in
|
||||||
|
`Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`_
|
||||||
|
by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
|
||||||
|
It's a causal (uni-directional) transformer with relative positioning (sinusoïdal) embeddings which can reuse
|
||||||
|
previously computed hidden-states to attend to longer context (memory).
|
||||||
|
This model also uses adaptive softmax inputs and outputs (tied).
|
||||||
|
|
||||||
|
The abstract from the paper is the following:
|
||||||
|
|
||||||
|
*Transformers have a potential of learning longer-term dependency, but are limited by a fixed-length context in the
|
||||||
|
setting of language modeling. We propose a novel neural architecture Transformer-XL that enables learning dependency
|
||||||
|
beyond a fixed length without disrupting temporal coherence. It consists of a segment-level recurrence mechanism and
|
||||||
|
a novel positional encoding scheme. Our method not only enables capturing longer-term dependency, but also resolves
|
||||||
|
the context fragmentation problem. As a result, Transformer-XL learns dependency that is 80% longer than RNNs and
|
||||||
|
450% longer than vanilla Transformers, achieves better performance on both short and long sequences, and is up
|
||||||
|
to 1,800+ times faster than vanilla Transformers during evaluation. Notably, we improve the state-of-the-art results
|
||||||
|
of bpc/perplexity to 0.99 on enwiki8, 1.08 on text8, 18.3 on WikiText-103, 21.8 on One Billion Word, and 54.5 on
|
||||||
|
Penn Treebank (without finetuning). When trained only on WikiText-103, Transformer-XL manages to generate reasonably
|
||||||
|
coherent, novel text articles with thousands of tokens.*
|
||||||
|
|
||||||
|
Tips:
|
||||||
|
|
||||||
|
- Transformer-XL uses relative sinusoidal positional embeddings so it's usually advised to pad the inputs on
|
||||||
|
the left rather than the right.
|
||||||
|
|
||||||
|
|
||||||
``TransfoXLConfig``
|
``TransfoXLConfig``
|
||||||
~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from .configuration_transfo_xl import TransfoXLConfig
|
from .configuration_transfo_xl import TransfoXLConfig
|
||||||
from .file_utils import add_start_docstrings
|
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_transfo_xl_utilities import LogUniformSampler, ProjectedAdaptiveLogSoftmax, sample_logits
|
from .modeling_transfo_xl_utilities import LogUniformSampler, ProjectedAdaptiveLogSoftmax, sample_logits
|
||||||
from .modeling_utils import PreTrainedModel
|
from .modeling_utils import PreTrainedModel
|
||||||
|
|
||||||
@@ -508,21 +508,11 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
|
|||||||
self._init_bias(m.r_bias)
|
self._init_bias(m.r_bias)
|
||||||
|
|
||||||
|
|
||||||
TRANSFO_XL_START_DOCSTRING = r""" The Transformer-XL model was proposed in
|
TRANSFO_XL_START_DOCSTRING = r"""
|
||||||
`Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`_
|
|
||||||
by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
|
|
||||||
It's a causal (uni-directional) transformer with relative positioning (sinusoïdal) embeddings which can reuse
|
|
||||||
previously computed hidden-states to attend to longer context (memory).
|
|
||||||
This model also uses adaptive softmax inputs and outputs (tied).
|
|
||||||
|
|
||||||
This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
|
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
|
||||||
refer to the PyTorch documentation for all matter related to general usage and behavior.
|
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
|
||||||
|
usage and behavior.
|
||||||
.. _`Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`:
|
|
||||||
https://arxiv.org/abs/1901.02860
|
|
||||||
|
|
||||||
.. _`torch.nn.Module`:
|
|
||||||
https://pytorch.org/docs/stable/nn.html#module
|
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
config (:class:`~transformers.TransfoXLConfig`): Model configuration class with all the parameters of the model.
|
config (:class:`~transformers.TransfoXLConfig`): Model configuration class with all the parameters of the model.
|
||||||
@@ -531,24 +521,25 @@ TRANSFO_XL_START_DOCSTRING = r""" The Transformer-XL model was proposed in
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
TRANSFO_XL_INPUTS_DOCSTRING = r"""
|
TRANSFO_XL_INPUTS_DOCSTRING = r"""
|
||||||
Inputs:
|
Args:
|
||||||
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||||
Indices of input sequence tokens in the vocabulary.
|
Indices of input sequence tokens in the vocabulary.
|
||||||
Transformer-XL is a model with relative position embeddings so you can either pad the inputs on
|
|
||||||
the right or on the left.
|
|
||||||
Indices can be obtained using :class:`transformers.TransfoXLTokenizer`.
|
Indices can be obtained using :class:`transformers.TransfoXLTokenizer`.
|
||||||
See :func:`transformers.PreTrainedTokenizer.encode` and
|
See :func:`transformers.PreTrainedTokenizer.encode` and
|
||||||
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
:func:`transformers.PreTrainedTokenizer.encode_plus` for details.
|
||||||
**mems**: (`optional`)
|
|
||||||
list of ``torch.FloatTensor`` (one for each layer):
|
`What are input IDs? <../glossary.html#input-ids>`__
|
||||||
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
||||||
(see `mems` output below). Can be used to speed up sequential decoding and attend to longer context.
|
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||||
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
|
(see `mems` output below). Can be used to speed up sequential decoding. The token ids which have their mems
|
||||||
|
given to this model should not be passed as input ids as they have already been computed.
|
||||||
|
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
|
||||||
Mask to nullify selected heads of the self-attention modules.
|
Mask to nullify selected heads of the self-attention modules.
|
||||||
Mask values selected in ``[0, 1]``:
|
Mask values selected in ``[0, 1]``:
|
||||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
:obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
|
||||||
**inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
|
input_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
||||||
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
|
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||||
than the model's internal embedding lookup matrix.
|
than the model's internal embedding lookup matrix.
|
||||||
"""
|
"""
|
||||||
@@ -557,34 +548,8 @@ TRANSFO_XL_INPUTS_DOCSTRING = r"""
|
|||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
|
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
|
||||||
TRANSFO_XL_START_DOCSTRING,
|
TRANSFO_XL_START_DOCSTRING,
|
||||||
TRANSFO_XL_INPUTS_DOCSTRING,
|
|
||||||
)
|
)
|
||||||
class TransfoXLModel(TransfoXLPreTrainedModel):
|
class TransfoXLModel(TransfoXLPreTrainedModel):
|
||||||
r"""
|
|
||||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
|
||||||
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
|
|
||||||
Sequence of hidden-states at the last layer of the model.
|
|
||||||
**mems**:
|
|
||||||
list of ``torch.FloatTensor`` (one for each layer):
|
|
||||||
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
|
||||||
(see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
|
|
||||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
|
||||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
|
||||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
|
||||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
|
||||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
|
||||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
|
||||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
|
||||||
|
|
||||||
Examples::
|
|
||||||
|
|
||||||
tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
|
|
||||||
model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
|
|
||||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
|
||||||
outputs = model(input_ids)
|
|
||||||
last_hidden_states, mems = outputs[:2]
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
@@ -705,7 +670,38 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
|
|||||||
|
|
||||||
return new_mems
|
return new_mems
|
||||||
|
|
||||||
|
@add_start_docstrings_to_callable(TRANSFO_XL_INPUTS_DOCSTRING)
|
||||||
def forward(self, input_ids=None, mems=None, head_mask=None, inputs_embeds=None):
|
def forward(self, input_ids=None, mems=None, head_mask=None, inputs_embeds=None):
|
||||||
|
r"""
|
||||||
|
Return:
|
||||||
|
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (config) and inputs:
|
||||||
|
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||||
|
Sequence of hidden-states at the last layer of the model.
|
||||||
|
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
||||||
|
Contains pre-computed hidden-states (key and values in the attention blocks).
|
||||||
|
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
|
||||||
|
should not be passed as input ids as they have already been computed.
|
||||||
|
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||||
|
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||||
|
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
|
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||||
|
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
|
heads.
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
|
||||||
|
model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
|
||||||
|
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
||||||
|
outputs = model(input_ids)
|
||||||
|
last_hidden_states, mems = outputs[:2]
|
||||||
|
|
||||||
|
"""
|
||||||
# the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
|
# the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
|
||||||
# so we transpose here from shape [bsz, len] to shape [len, bsz]
|
# so we transpose here from shape [bsz, len] to shape [len, bsz]
|
||||||
if input_ids is not None and inputs_embeds is not None:
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
@@ -805,44 +801,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
|
|||||||
"""The Transformer-XL Model with a language modeling head on top
|
"""The Transformer-XL Model with a language modeling head on top
|
||||||
(adaptive softmax with weights tied to the adaptive input embeddings)""",
|
(adaptive softmax with weights tied to the adaptive input embeddings)""",
|
||||||
TRANSFO_XL_START_DOCSTRING,
|
TRANSFO_XL_START_DOCSTRING,
|
||||||
TRANSFO_XL_INPUTS_DOCSTRING,
|
|
||||||
)
|
)
|
||||||
class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
|
class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
|
||||||
r"""
|
|
||||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
|
||||||
Labels for language modeling.
|
|
||||||
Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``
|
|
||||||
Indices are selected in ``[-1, 0, ..., config.vocab_size]``
|
|
||||||
All labels set to ``-100`` are ignored (masked), the loss is only
|
|
||||||
computed for labels in ``[0, ..., config.vocab_size]``
|
|
||||||
|
|
||||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
|
||||||
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
|
||||||
Language modeling loss.
|
|
||||||
**prediction_scores**: ``None`` if ``labels`` is provided else ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
|
|
||||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
|
||||||
We don't output them when the loss is computed to speedup adaptive softmax decoding.
|
|
||||||
**mems**:
|
|
||||||
list of ``torch.FloatTensor`` (one for each layer):
|
|
||||||
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
|
||||||
(see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
|
|
||||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
|
||||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
|
||||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
|
||||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
|
||||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
|
||||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
|
||||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
|
||||||
|
|
||||||
Examples::
|
|
||||||
|
|
||||||
tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
|
|
||||||
model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')
|
|
||||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
|
||||||
outputs = model(input_ids)
|
|
||||||
prediction_scores, mems = outputs[:2]
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
@@ -891,7 +851,47 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
|
|||||||
def init_mems(self, bsz):
|
def init_mems(self, bsz):
|
||||||
return self.transformer.init_mems(bsz)
|
return self.transformer.init_mems(bsz)
|
||||||
|
|
||||||
|
@add_start_docstrings_to_callable(TRANSFO_XL_INPUTS_DOCSTRING)
|
||||||
def forward(self, input_ids=None, mems=None, head_mask=None, inputs_embeds=None, labels=None):
|
def forward(self, input_ids=None, mems=None, head_mask=None, inputs_embeds=None, labels=None):
|
||||||
|
r"""
|
||||||
|
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||||
|
Labels for language modeling.
|
||||||
|
Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
|
||||||
|
Indices are selected in ``[-100, 0, ..., config.vocab_size]``
|
||||||
|
All labels set to ``-100`` are ignored (masked), the loss is only
|
||||||
|
computed for labels in ``[0, ..., config.vocab_size]``
|
||||||
|
|
||||||
|
Return:
|
||||||
|
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:obj:`~transformers.GPT2Config`) and inputs:
|
||||||
|
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when ``labels`` is provided)
|
||||||
|
Language modeling loss.
|
||||||
|
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||||
|
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||||
|
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
||||||
|
Contains pre-computed hidden-states (key and values in the attention blocks).
|
||||||
|
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
|
||||||
|
should not be passed as input ids as they have already been computed.
|
||||||
|
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||||
|
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||||
|
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
|
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||||
|
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
|
heads.
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
|
||||||
|
model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')
|
||||||
|
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
||||||
|
outputs = model(input_ids)
|
||||||
|
prediction_scores, mems = outputs[:2]
|
||||||
|
|
||||||
|
"""
|
||||||
if input_ids is not None:
|
if input_ids is not None:
|
||||||
bsz, tgt_len = input_ids.size(0), input_ids.size(1)
|
bsz, tgt_len = input_ids.size(0), input_ids.size(1)
|
||||||
elif inputs_embeds is not None:
|
elif inputs_embeds is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user