[Generate] Facilitate PyTorch generate using ModelOutputs (#6735)

* fix generate for GPT2 Double Head

* fix gpt2 double head model

* fix  bart / t5

* also add for no beam search

* fix no beam search

* fix encoder decoder

* simplify t5

* simplify t5

* fix t5 tests

* fix BART

* fix transfo-xl

* fix conflict

* integrating Sylvain's and Sam's comments

* fix tf past_decoder_key_values

* fix enc dec test
This commit is contained in:
Patrick von Platen
2020-09-01 12:38:25 +02:00
committed by GitHub
parent 397f819615
commit afc4ece462
20 changed files with 393 additions and 259 deletions

View File

@@ -20,6 +20,7 @@ import torch
from torch import Tensor
from torch.nn import functional as F
from .file_utils import ModelOutput
from .utils import logging
@@ -46,14 +47,6 @@ class GenerationMixin:
"""
return logits
def _use_cache(self, outputs, use_cache):
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
# No `past` is passed on when the model returned at most one output or when
# the caller explicitly disabled caching with `use_cache=False`.
if len(outputs) <= 1 or use_cache is False:
return False
# A config that defines `mem_len` and sets it to 0 also disables passing `past`.
if hasattr(self.config, "mem_len") and self.config.mem_len == 0:
return False
return True
def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
"""
Enforce the repetition penalty (from the `CTRL paper <https://arxiv.org/abs/1909.05858>`__).
@@ -137,7 +130,7 @@ class GenerationMixin:
attention_mask: Optional[torch.LongTensor] = None,
decoder_start_token_id: Optional[int] = None,
use_cache: Optional[bool] = None,
**model_specific_kwargs
**model_kwargs
) -> torch.LongTensor:
r"""
Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
@@ -208,7 +201,7 @@ class GenerationMixin:
use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
speed up decoding.
model_specific_kwargs:
model_kwargs:
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
Return:
@@ -400,7 +393,7 @@ class GenerationMixin:
# get encoder and store encoder outputs
encoder = self.get_encoder()
encoder_outputs: tuple = encoder(input_ids, attention_mask=attention_mask)
encoder_outputs: ModelOutput = encoder(input_ids, attention_mask=attention_mask, return_dict=True)
# Expand input ids if num_beams > 1 or num_return_sequences > 1
if num_return_sequences > 1 or num_beams > 1:
@@ -428,8 +421,8 @@ class GenerationMixin:
cur_len = 1
assert (
batch_size == encoder_outputs[0].shape[0]
), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} "
batch_size == encoder_outputs.last_hidden_state.shape[0]
), f"expected encoder_outputs.last_hidden_state to have 1st dimension bs={batch_size}, got {encoder_outputs.last_hidden_state.shape[0]} "
# expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
expanded_batch_idxs = (
@@ -439,11 +432,16 @@ class GenerationMixin:
.view(-1)
.to(input_ids.device)
)
# expand encoder_outputs
encoder_outputs = (encoder_outputs[0].index_select(0, expanded_batch_idxs), *encoder_outputs[1:])
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
0, expanded_batch_idxs
)
# save encoder_outputs in `model_kwargs`
model_kwargs["encoder_outputs"] = encoder_outputs
else:
encoder_outputs = None
cur_len = input_ids.shape[-1]
assert (
@@ -471,10 +469,9 @@ class GenerationMixin:
length_penalty=length_penalty,
num_beams=num_beams,
vocab_size=vocab_size,
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
use_cache=use_cache,
model_specific_kwargs=model_specific_kwargs,
model_kwargs=model_kwargs,
)
else:
output = self._generate_no_beam_search(
@@ -492,10 +489,9 @@ class GenerationMixin:
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
batch_size=effective_batch_size,
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
use_cache=use_cache,
model_specific_kwargs=model_specific_kwargs,
model_kwargs=model_kwargs,
)
return output
@@ -516,10 +512,9 @@ class GenerationMixin:
pad_token_id,
eos_token_id,
batch_size,
encoder_outputs,
attention_mask,
use_cache,
model_specific_kwargs,
model_kwargs,
):
"""Generate sequences for each example without beam search (num_beams == 1).
All returned sequences are generated independently.
@@ -528,15 +523,14 @@ class GenerationMixin:
unfinished_sents = input_ids.new(batch_size).fill_(1)
sent_lengths = input_ids.new(batch_size).fill_(max_length)
past = (encoder_outputs, None) if encoder_outputs is not None else None
past = None
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
)
outputs = self(**model_inputs)
next_token_logits = outputs[0][:, -1, :]
outputs = self(**model_inputs, return_dict=True)
next_token_logits = outputs.logits[:, -1, :]
scores = self.postprocess_next_token_scores(
scores=next_token_logits,
@@ -553,8 +547,10 @@ class GenerationMixin:
)
# if model has past, then set the past variable to speed up decoding
if self._use_cache(outputs, use_cache):
past = outputs[1]
if "past_key_values" in outputs:
past = outputs.past_key_values
elif "mems" in outputs:
past = outputs.mems
if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens)
@@ -621,10 +617,9 @@ class GenerationMixin:
length_penalty,
num_beams,
vocab_size,
encoder_outputs,
attention_mask,
use_cache,
model_specific_kwargs,
model_kwargs,
):
"""Generate sequences for each example with beam search."""
@@ -643,21 +638,24 @@ class GenerationMixin:
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
# cache compute states
past = (encoder_outputs, None) if encoder_outputs is not None else None
past = None
# done sentences
done = [False for _ in range(batch_size)]
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
)
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
outputs = self(**model_inputs, return_dict=True) # (batch_size * num_beams, cur_len, vocab_size)
next_token_logits = outputs.logits[:, -1, :] # (batch_size * num_beams, vocab_size)
# if model has past, then set the past variable to speed up decoding
if self._use_cache(outputs, use_cache):
past = outputs[1]
if "past_key_values" in outputs:
past = outputs.past_key_values
elif "mems" in outputs:
past = outputs.mems
if self.config.is_encoder_decoder and do_sample is False:
# TODO (PVP) still a bit hacky here - there might be a better solution
next_token_logits = self.adjust_logits_during_generation(

View File

@@ -111,15 +111,15 @@ BART_INPUTS_DOCSTRING = r"""
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify.
See diagram 1 in the paper for more info on the default strategy
decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains pre-computed key and value hidden-states of the attention blocks.
Can be used to speed up decoding.
If ``decoder_past_key_value_states`` are used, the user can optionally input only the last
If ``past_key_values`` are used, the user can optionally input only the last
``decoder_input_ids`` (those that don't have their past key value states given to this model) of shape
:obj:`(batch_size, 1)` instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
If `use_cache` is True, ``decoder_past_key_values`` are returned and can be used to speed up decoding (see
``decoder_past_key_values``).
If `use_cache` is True, ``past_key_values`` are returned and can be used to speed up decoding (see
``past_key_values``).
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
@@ -502,7 +502,7 @@ class BartDecoder(nn.Module):
encoder_padding_mask,
decoder_padding_mask,
decoder_causal_mask,
decoder_past_key_values=None,
past_key_values=None,
use_cache=False,
output_attentions=False,
output_hidden_states=False,
@@ -519,7 +519,7 @@ class BartDecoder(nn.Module):
encoder_hidden_states: output from the encoder, used for
encoder-side attention
encoder_padding_mask: for ignoring pad tokens
decoder_past_key_values (dict or None): dictionary used for storing state during generation
past_key_values (dict or None): dictionary used for storing state during generation
Returns:
BaseModelOutputWithPast or tuple:
@@ -530,10 +530,16 @@ class BartDecoder(nn.Module):
"""
if "decoder_cached_states" in unused:
warnings.warn(
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.",
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
decoder_past_key_values = unused.pop("decoder_cached_states")
past_key_values = unused.pop("decoder_cached_states")
if "decoder_past_key_values" in unused:
warnings.warn(
"The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
past_key_values = unused.pop("decoder_past_key_values")
# check attention mask and invert
if encoder_padding_mask is not None:
@@ -568,7 +574,7 @@ class BartDecoder(nn.Module):
if self.training and (dropout_probability < self.layerdrop):
continue
layer_state = decoder_past_key_values[idx] if decoder_past_key_values is not None else None
layer_state = past_key_values[idx] if past_key_values is not None else None
x, layer_self_attn, layer_past = decoder_layer(
x,
@@ -594,10 +600,7 @@ class BartDecoder(nn.Module):
x = x.transpose(0, 1)
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
if use_cache:
next_cache = ((encoder_hidden_states, encoder_padding_mask), next_decoder_cache)
else:
next_cache = None
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [x, next_cache, all_hidden_states, all_self_attns] if v is not None)
@@ -869,13 +872,19 @@ class BartModel(PretrainedBartModel):
decoder_input_ids=None,
encoder_outputs: Optional[Tuple] = None,
decoder_attention_mask=None,
decoder_past_key_values=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs,
):
if "decoder_past_key_values" in kwargs:
warnings.warn(
"The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
past_key_values = kwargs.pop("decoder_past_key_values")
if decoder_input_ids is None:
use_cache = False
@@ -924,7 +933,7 @@ class BartModel(PretrainedBartModel):
attention_mask,
decoder_padding_mask,
decoder_causal_mask=causal_mask,
decoder_past_key_values=decoder_past_key_values,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
@@ -936,7 +945,7 @@ class BartModel(PretrainedBartModel):
return Seq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
decoder_past_key_values=decoder_outputs.past_key_values,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
@@ -994,7 +1003,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
encoder_outputs=None,
decoder_input_ids=None,
decoder_attention_mask=None,
decoder_past_key_values=None,
past_key_values=None,
labels=None,
use_cache=None,
output_attentions=None,
@@ -1037,10 +1046,16 @@ class BartForConditionalGeneration(PretrainedBartModel):
labels = unused.pop("lm_labels")
if "decoder_cached_states" in unused:
warnings.warn(
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.",
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
decoder_past_key_values = unused.pop("decoder_cached_states")
past_key_values = unused.pop("decoder_cached_states")
if "decoder_past_key_values" in unused:
warnings.warn(
"The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
past_key_values = unused.pop("decoder_past_key_values")
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None:
@@ -1054,7 +1069,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
decoder_past_key_values=decoder_past_key_values,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
@@ -1075,7 +1090,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
return Seq2SeqLMOutput(
loss=masked_lm_loss,
logits=lm_logits,
decoder_past_key_values=outputs.decoder_past_key_values,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
@@ -1083,14 +1098,13 @@ class BartForConditionalGeneration(PretrainedBartModel):
encoder_attentions=outputs.encoder_attentions,
)
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs):
assert past is not None, "past has to be defined for encoder_outputs"
encoder_outputs, decoder_past_key_values = past
def prepare_inputs_for_generation(
self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs
):
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"decoder_past_key_values": decoder_past_key_values,
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
@@ -1109,20 +1123,14 @@ class BartForConditionalGeneration(PretrainedBartModel):
@staticmethod
def _reorder_cache(past, beam_idx):
((enc_out, enc_mask), decoder_past_key_values) = past
reordered_past = []
for layer_past in decoder_past_key_values:
for layer_past in past:
# get the correct batch idx from decoder layer's batch dim for cross and self-attn
layer_past_new = {
attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
}
reordered_past.append(layer_past_new)
new_enc_out = enc_out if enc_out is None else enc_out.index_select(0, beam_idx)
new_enc_mask = enc_mask if enc_mask is None else enc_mask.index_select(0, beam_idx)
past = ((new_enc_out, new_enc_mask), reordered_past)
return past
return reordered_past
def get_encoder(self):
return self.model.encoder
@@ -1208,7 +1216,7 @@ class BartForSequenceClassification(PretrainedBartModel):
return Seq2SeqSequenceClassifierOutput(
loss=loss,
logits=logits,
decoder_past_key_values=outputs.decoder_past_key_values,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
@@ -1316,7 +1324,7 @@ class BartForQuestionAnswering(PretrainedBartModel):
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
decoder_past_key_values=outputs.decoder_past_key_values,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,

View File

@@ -19,13 +19,79 @@ from typing import Optional
from .configuration_encoder_decoder import EncoderDecoderConfig
from .configuration_utils import PretrainedConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable, replace_return_docstrings
from .modeling_outputs import Seq2SeqLMOutput
from .modeling_utils import PreTrainedModel
from .utils import logging
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "EncoderDecoderConfig"
ENCODER_DECODER_START_DOCSTRING = r"""
This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via the :meth:`~transformers.AutoModel.from_pretrained` function and the decoder is loaded via the :meth:`~transformers.AutoModelForCausalLM.from_pretrained` function.
Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream generative task, *i.e.* summarization.
The effectiveness of initializing sequence-to-sequence models with pre-trained checkpoints for sequence generation tasks was shown in `Leveraging Pre-trained Checkpoints for Sequence Generation Tasks <https://arxiv.org/abs/1907.12461>`__ by Sascha Rothe, Shashi Narayan, Aliaksei Severyn.
Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu.
After such an Encoder Decoder model has been trained / fine-tuned, it can be saved / loaded just like any other models (see Examples for more information).
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#module>`__ sub-class. Use it as a
regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.
Parameters:
config (:class:`~transformers.EncoderDecoderConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
ENCODER_DECODER_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary for the encoder.
Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`.
See :meth:`~transformers.PreTrainedTokenizer.encode` and
:meth:`~transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on padding token indices for the encoder.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
encoder_outputs (:obj:`tuple(torch.FloatTensor)`, `optional`, defaults to :obj:`None`):
This tuple must consist of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`)
`last_hidden_state` (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`) is a tensor of hidden-states at the output of the last layer of the encoder.
Used in the cross-attention of the decoder.
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
Provide for sequence to sequence training to the decoder.
Indices can be obtained using :class:`transformers.PreTrainedTokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the masked language modeling loss for the decoder.
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``
return_dict (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the model will return a :class:`~transformers.modeling_outputs.Seq2SeqLMOutput` instead of a
plain tuple.
kwargs: (`optional`) Remaining dictionary of keyword arguments. Keyword arguments come in two flavors:
- Without a prefix which will be input as ``**encoder_kwargs`` for the encoder forward function.
- With a `decoder_` prefix which will be input as ``**decoder_kwargs`` for the decoder forward function.
"""
@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING)
class EncoderDecoderModel(PreTrainedModel):
r"""
:class:`~transformers.EncoderDecoderModel` is a generic model class that will be
@@ -206,6 +272,8 @@ class EncoderDecoderModel(PreTrainedModel):
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
return cls(encoder=encoder, decoder=decoder, config=config)
@add_start_docstrings_to_callable(ENCODER_DECODER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
@@ -216,47 +284,11 @@ class EncoderDecoderModel(PreTrainedModel):
decoder_attention_mask=None,
decoder_inputs_embeds=None,
labels=None,
return_dict=None,
**kwargs,
):
"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary for the encoder.
Indices can be obtained using :class:`transformers.PretrainedTokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on padding token indices for the encoder.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`, defaults to :obj:`None`):
Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`)
`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`) is a sequence of hidden-states at the output of the last layer of the encoder.
Used in the cross-attention of the decoder.
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
Provide for sequence to sequence training to the decoder.
Indices can be obtained using :class:`transformers.PretrainedTokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the masked language modeling loss for the decoder.
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``
kwargs: (`optional`) Remaining dictionary of keyword arguments. Keyword arguments come in two flavors:
- Without a prefix which will be input as `**encoder_kwargs` for the encoder forward function.
- With a `decoder_` prefix which will be input as `**decoder_kwargs` for the decoder forward function.
r"""
Returns:
Examples::
@@ -264,19 +296,25 @@ class EncoderDecoderModel(PreTrainedModel):
>>> import torch
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert from pre-trained checkpoints
>>> # forward
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
>>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)
>>> # training
>>> loss, outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=input_ids)[:2]
>>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=input_ids, return_dict=True)
>>> loss, logits = outputs.loss, outputs.logits
>>> # save and load from pretrained
>>> model.save_pretrained("bert2bert")
>>> model = EncoderDecoderModel.from_pretrained("bert2bert")
>>> # generation
>>> generated = model.generate(input_ids, decoder_start_token_id=model.config.decoder.pad_token_id)
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
@@ -289,7 +327,7 @@ class EncoderDecoderModel(PreTrainedModel):
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
return_dict=False,
return_dict=return_dict,
**kwargs_encoder,
)
@@ -303,23 +341,28 @@ class EncoderDecoderModel(PreTrainedModel):
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
labels=labels,
return_dict=False,
return_dict=return_dict,
**kwargs_decoder,
)
# TODO(PVP): currently it is not possible to use `past`
# with the encoder/decoder framework -> should be implemented
if not return_dict:
return decoder_outputs + encoder_outputs
return Seq2SeqLMOutput(
loss=decoder_outputs.loss,
logits=decoder_outputs.logits,
past_key_values=None, # TODO(PVP) - need to implement cache for BERT, etc... before this works
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
return decoder_outputs + encoder_outputs
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwargs):
assert past is not None, "past has to be defined for encoder_outputs"
# first step
if type(past) is tuple:
encoder_outputs, _ = past
else:
encoder_outputs = (past,)
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, encoder_outputs, **kwargs):
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids)
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
input_dict = {
@@ -335,7 +378,7 @@ class EncoderDecoderModel(PreTrainedModel):
input_dict["decoder_use_cache"] = decoder_inputs["use_cache"]
if "past_key_values" in decoder_inputs:
input_dict["decoder_past_key_values"] = decoder_inputs["past_key_values"]
input_dict["past_key_values"] = decoder_inputs["past_key_values"]
return input_dict

View File

@@ -353,11 +353,11 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
Base class for outputs of models with a language modeling head and a multiple choice classification head.
Args:
lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
Language modeling loss.
mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided):
Multiple choice classification loss.
lm_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
@@ -380,9 +380,9 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
heads.
"""
lm_loss: Optional[torch.FloatTensor] = None
loss: Optional[torch.FloatTensor] = None
mc_loss: Optional[torch.FloatTensor] = None
lm_logits: torch.FloatTensor = None
logits: torch.FloatTensor = None
mc_logits: torch.FloatTensor = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@@ -777,6 +777,17 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
def get_output_embeddings(self):
return self.lm_head
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
# Build the keyword arguments for the next forward pass during generation.
# When `past` is truthy, keep only the last token of `input_ids`
# (`unsqueeze(-1)` restores the trailing dimension dropped by the indexing).
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
# `use_cache` is read from kwargs and is None when the caller did not pass it.
return {
"input_ids": input_ids,
"past_key_values": past,
"use_cache": kwargs.get("use_cache"),
}
@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
@@ -893,9 +904,9 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
return ((lm_loss,) + output) if lm_loss is not None else output
return GPT2DoubleHeadsModelOutput(
lm_loss=lm_loss,
loss=lm_loss,
mc_loss=mc_loss,
lm_logits=lm_logits,
logits=lm_logits,
mc_logits=mc_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,

View File

@@ -300,11 +300,11 @@ class OpenAIGPTDoubleHeadsModelOutput(ModelOutput):
Base class for outputs of models predicting if two sentences are consecutive or not.
Args:
lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
Language modeling loss.
mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided):
Multiple choice classification loss.
lm_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
@@ -321,9 +321,9 @@ class OpenAIGPTDoubleHeadsModelOutput(ModelOutput):
heads.
"""
lm_loss: Optional[torch.FloatTensor] = None
loss: Optional[torch.FloatTensor] = None
mc_loss: Optional[torch.FloatTensor] = None
lm_logits: torch.FloatTensor = None
logits: torch.FloatTensor = None
mc_logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@@ -713,9 +713,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
return ((lm_loss,) + output) if lm_loss is not None else output
return OpenAIGPTDoubleHeadsModelOutput(
lm_loss=lm_loss,
loss=lm_loss,
mc_loss=mc_loss,
lm_logits=lm_logits,
logits=lm_logits,
mc_logits=mc_logits,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,

View File

@@ -109,13 +109,13 @@ class Seq2SeqModelOutput(ModelOutput):
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the decoder of the model.
If ``decoder_past_key_values`` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
decoder_past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
If ``past_key_values`` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
used (see ``past_key_values`` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@@ -143,7 +143,7 @@ class Seq2SeqModelOutput(ModelOutput):
"""
last_hidden_state: torch.FloatTensor
decoder_past_key_values: Optional[List[torch.FloatTensor]] = None
past_key_values: Optional[List[torch.FloatTensor]] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
@@ -255,12 +255,12 @@ class Seq2SeqLMOutput(ModelOutput):
Language modeling loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
decoder_past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
used (see ``past_key_values`` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@@ -289,7 +289,7 @@ class Seq2SeqLMOutput(ModelOutput):
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
decoder_past_key_values: Optional[List[torch.FloatTensor]] = None
past_key_values: Optional[List[torch.FloatTensor]] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
@@ -365,12 +365,12 @@ class Seq2SeqSequenceClassifierOutput(ModelOutput):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
decoder_past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
used (see ``past_key_values`` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@@ -399,7 +399,7 @@ class Seq2SeqSequenceClassifierOutput(ModelOutput):
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
decoder_past_key_values: Optional[List[torch.FloatTensor]] = None
past_key_values: Optional[List[torch.FloatTensor]] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
@@ -511,12 +511,12 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput):
Span-start scores (before SoftMax).
end_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
decoder_past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
used (see ``past_key_values`` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@@ -546,7 +546,7 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput):
loss: Optional[torch.FloatTensor] = None
start_logits: torch.FloatTensor = None
end_logits: torch.FloatTensor = None
decoder_past_key_values: Optional[List[torch.FloatTensor]] = None
past_key_values: Optional[List[torch.FloatTensor]] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None

View File

@@ -838,27 +838,27 @@ T5_INPUTS_DOCSTRING = r"""
Used in the cross-attention of the decoder.
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation.
If `decoder_past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `decoder_past_key_values`).
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
To know more on how to prepare :obj:`decoder_input_ids` for pre-training take a look at
`T5 Training <./t5.html#training>`__. If decoder_input_ids and decoder_inputs_embeds are both None,
decoder_input_ids takes the value of input_ids.
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
decoder_past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains pre-computed key and value hidden-states of the attention blocks.
Can be used to speed up decoding.
If `decoder_past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all `decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
If `use_cache` is True, `decoder_past_key_values` are returned and can be used to speed up decoding (see `decoder_past_key_values`).
If `use_cache` is True, `past_key_values` are returned and can be used to speed up decoding (see `past_key_values`).
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation.
If `decoder_past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be input (see `decoder_past_key_values`).
If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be input (see `past_key_values`).
This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
than the model's internal embedding lookup matrix. If decoder_input_ids and decoder_inputs_embeds are both None,
decoder_inputs_embeds takes the value of inputs_embeds.
@@ -928,7 +928,7 @@ class T5Model(T5PreTrainedModel):
encoder_outputs=None,
decoder_input_ids=None,
decoder_attention_mask=None,
decoder_past_key_values=None,
past_key_values=None,
use_cache=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
@@ -955,10 +955,16 @@ class T5Model(T5PreTrainedModel):
"""
if "decoder_past_key_value_states" in kwargs:
warnings.warn(
"The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.",
"The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
decoder_past_key_values = kwargs.pop("decoder_past_key_value_states")
past_key_values = kwargs.pop("decoder_past_key_value_states")
if "decoder_past_key_values" in kwargs:
warnings.warn(
"The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
past_key_values = kwargs.pop("decoder_past_key_values")
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
use_cache = use_cache if use_cache is not None else self.config.use_cache
@@ -992,7 +998,7 @@ class T5Model(T5PreTrainedModel):
# If decoding with past key value states, only the last tokens
# should be given as an input
if decoder_past_key_values is not None:
if past_key_values is not None:
if decoder_input_ids is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
if decoder_inputs_embeds is not None:
@@ -1003,7 +1009,7 @@ class T5Model(T5PreTrainedModel):
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
inputs_embeds=decoder_inputs_embeds,
past_key_value_states=decoder_past_key_values,
past_key_value_states=past_key_values,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
head_mask=head_mask,
@@ -1013,15 +1019,12 @@ class T5Model(T5PreTrainedModel):
return_dict=return_dict,
)
past = (encoder_outputs, decoder_outputs[1]) if use_cache is True else None
if not return_dict:
if past is not None:
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
return decoder_outputs + encoder_outputs
return Seq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
decoder_past_key_values=past,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
@@ -1080,7 +1083,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
encoder_outputs=None,
decoder_input_ids=None,
decoder_attention_mask=None,
decoder_past_key_values=None,
past_key_values=None,
use_cache=None,
labels=None,
inputs_embeds=None,
@@ -1127,10 +1130,16 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
labels = kwargs.pop("lm_labels")
if "decoder_past_key_value_states" in kwargs:
warnings.warn(
"The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.",
"The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
decoder_past_key_values = kwargs.pop("decoder_past_key_value_states")
past_key_values = kwargs.pop("decoder_past_key_value_states")
if "decoder_past_key_values" in kwargs:
warnings.warn(
"The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
past_key_values = kwargs.pop("decoder_past_key_values")
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
use_cache = use_cache if use_cache is not None else self.config.use_cache
@@ -1163,7 +1172,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
# If decoding with past key value states, only the last tokens
# should be given as an input
if decoder_past_key_values is not None:
if past_key_values is not None:
assert labels is None, "Decoder should not use cached key value states when training."
if decoder_input_ids is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
@@ -1175,7 +1184,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
inputs_embeds=decoder_inputs_embeds,
past_key_value_states=decoder_past_key_values,
past_key_value_states=past_key_values,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
head_mask=head_mask,
@@ -1197,17 +1206,14 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
# TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
past = (encoder_outputs, decoder_outputs[1]) if use_cache is True else None
if not return_dict:
if past is not None:
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
return ((loss,) + output) if loss is not None else output
return Seq2SeqLMOutput(
loss=loss,
logits=lm_logits,
decoder_past_key_values=past,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
@@ -1215,14 +1221,10 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
encoder_attentions=encoder_outputs.attentions,
)
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, use_cache, **kwargs):
assert past is not None, "past has to be defined for encoder_outputs"
encoder_outputs, decoder_past_key_values = past
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs):
return {
"decoder_input_ids": input_ids,
"decoder_past_key_values": decoder_past_key_values,
"past_key_values": past,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"use_cache": use_cache,
@@ -1231,14 +1233,12 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
def _reorder_cache(self, past, beam_idx):
# if decoder past is not included in output
# speedy decoding is disabled and no need to reorder
if past[1] is None:
if past is None:
logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
return past
decoder_past = past[1]
past = (past[0],)
reordered_decoder_past = ()
for layer_past_states in decoder_past:
for layer_past_states in past:
# get the correct batch idx from layer past batch dim
# batch dim of `past` is at 2nd position
reordered_layer_past_states = ()
@@ -1252,4 +1252,4 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
assert len(reordered_layer_past_states) == len(layer_past_states)
reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
return past + (reordered_decoder_past,)
return reordered_decoder_past

View File

@@ -431,7 +431,7 @@ class TFGPT2DoubleHeadsModelOutput(ModelOutput):
Base class for outputs of models predicting if two sentences are consecutive or not.
Args:
lm_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
mc_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`):
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
@@ -454,7 +454,7 @@ class TFGPT2DoubleHeadsModelOutput(ModelOutput):
heads.
"""
lm_logits: tf.Tensor = None
logits: tf.Tensor = None
mc_logits: tf.Tensor = None
past_key_values: Optional[List[tf.Tensor]] = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
@@ -794,7 +794,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
return (lm_logits, mc_logits) + transformer_outputs[1:]
return TFGPT2DoubleHeadsModelOutput(
lm_logits=lm_logits,
logits=lm_logits,
mc_logits=mc_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,

View File

@@ -394,7 +394,7 @@ class TFOpenAIGPTDoubleHeadsModelOutput(ModelOutput):
Base class for outputs of models predicting if two sentences are consecutive or not.
Args:
lm_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
mc_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`):
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
@@ -411,7 +411,7 @@ class TFOpenAIGPTDoubleHeadsModelOutput(ModelOutput):
heads.
"""
lm_logits: tf.Tensor = None
logits: tf.Tensor = None
mc_logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
@@ -719,7 +719,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
return (lm_logits, mc_logits) + transformer_outputs[1:]
return TFOpenAIGPTDoubleHeadsModelOutput(
lm_logits=lm_logits,
logits=lm_logits,
mc_logits=mc_logits,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,

View File

@@ -113,13 +113,13 @@ class TFSeq2SeqModelOutput(ModelOutput):
last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the decoder of the model.
If ``decoder_past_key_values`` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
decoder_past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
If ``past_key_values`` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
used (see ``past_key_values`` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@@ -147,7 +147,7 @@ class TFSeq2SeqModelOutput(ModelOutput):
"""
last_hidden_state: tf.Tensor = None
decoder_past_key_values: Optional[List[tf.Tensor]] = None
past_key_values: Optional[List[tf.Tensor]] = None
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
encoder_last_hidden_state: Optional[tf.Tensor] = None
@@ -259,12 +259,12 @@ class TFSeq2SeqLMOutput(ModelOutput):
Language modeling loss.
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
decoder_past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
used (see ``past_key_values`` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@@ -293,7 +293,7 @@ class TFSeq2SeqLMOutput(ModelOutput):
loss: Optional[tf.Tensor] = None
logits: tf.Tensor = None
decoder_past_key_values: Optional[List[tf.Tensor]] = None
past_key_values: Optional[List[tf.Tensor]] = None
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
encoder_last_hidden_state: Optional[tf.Tensor] = None
@@ -366,12 +366,12 @@ class TFSeq2SeqSequenceClassifierOutput(ModelOutput):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
decoder_past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
used (see ``past_key_values`` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@@ -400,7 +400,7 @@ class TFSeq2SeqSequenceClassifierOutput(ModelOutput):
loss: Optional[tf.Tensor] = None
logits: tf.Tensor = None
decoder_past_key_values: Optional[List[tf.Tensor]] = None
past_key_values: Optional[List[tf.Tensor]] = None
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
encoder_last_hidden_state: Optional[tf.Tensor] = None
@@ -512,12 +512,12 @@ class TFSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
Span-start scores (before SoftMax).
end_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
decoder_past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
used (see ``past_key_values`` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@@ -547,7 +547,7 @@ class TFSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
loss: Optional[tf.Tensor] = None
start_logits: tf.Tensor = None
end_logits: tf.Tensor = None
decoder_past_key_values: Optional[List[tf.Tensor]] = None
past_key_values: Optional[List[tf.Tensor]] = None
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
encoder_last_hidden_state: Optional[tf.Tensor] = None

View File

@@ -437,15 +437,15 @@ class TFT5Block(tf.keras.layers.Layer):
):
if past_key_value_state is not None:
assert self.is_decoder, "Only decoder can use `past_key_value_states`"
expected_num_past_key_value_states = 2 if encoder_hidden_states is None else 4
assert self.is_decoder, "Only decoder can use `past_key_values`"
expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
error_message = "There should be {} past states. 2 (past / key) for self attention.{} Got {} past key / value states".format(
expected_num_past_key_value_states,
"2 (past / key) for cross attention" if expected_num_past_key_value_states == 4 else "",
expected_num_past_key_values,
"2 (past / key) for cross attention" if expected_num_past_key_values == 4 else "",
len(past_key_value_state),
)
assert len(past_key_value_state) == expected_num_past_key_value_states, error_message
assert len(past_key_value_state) == expected_num_past_key_values, error_message
self_attn_past_key_value_state = past_key_value_state[:2]
cross_attn_past_key_value_state = past_key_value_state[2:]
@@ -586,11 +586,12 @@ class TFT5MainLayer(tf.keras.layers.Layer):
encoder_attention_mask=None,
inputs_embeds=None,
head_mask=None,
past_key_value_states=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
training=False,
**kwargs,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
@@ -599,7 +600,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
encoder_attention_mask = inputs[3] if len(inputs) > 3 else encoder_attention_mask
inputs_embeds = inputs[4] if len(inputs) > 4 else inputs_embeds
head_mask = inputs[5] if len(inputs) > 5 else head_mask
past_key_value_states = inputs[6] if len(inputs) > 6 else past_key_value_states
past_key_values = inputs[6] if len(inputs) > 6 else past_key_values
use_cache = inputs[7] if len(inputs) > 7 else use_cache
output_attentions = inputs[8] if len(inputs) > 8 else output_attentions
output_hidden_states = inputs[9] if len(inputs) > 9 else output_hidden_states
@@ -611,13 +612,26 @@ class TFT5MainLayer(tf.keras.layers.Layer):
encoder_attention_mask = inputs.get("encoder_attention_mask", encoder_attention_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
head_mask = inputs.get("head_mask", head_mask)
past_key_value_states = inputs.get("past_key_value_states", past_key_value_states)
past_key_values = inputs.get("past_key_values", past_key_values)
use_cache = inputs.get("use_cache", use_cache)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
assert len(inputs) <= 10, "Too many inputs."
if "past_key_value_states" in inputs:
warnings.warn(
"The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
past_key_values = inputs.pop("past_key_value_states")
else:
input_ids = inputs
if "past_key_value_states" in kwargs:
warnings.warn(
"The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
past_key_values = kwargs.pop("past_key_value_states")
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
@@ -639,13 +653,13 @@ class TFT5MainLayer(tf.keras.layers.Layer):
batch_size, seq_length = input_shape
if past_key_value_states is not None:
if past_key_values is not None:
assert seq_length == 1, "Input shape is {}, but should be {} when using past_key_value_sates".format(
input_shape, (batch_size, 1)
)
# required mask seq length can be calculated via length of past
# key value states and seq_length = 1 for the last token
mask_seq_length = shape_list(past_key_value_states[0][0])[2] + seq_length
mask_seq_length = shape_list(past_key_values[0][0])[2] + seq_length
else:
mask_seq_length = seq_length
@@ -655,9 +669,9 @@ class TFT5MainLayer(tf.keras.layers.Layer):
encoder_seq_length = shape_list(encoder_hidden_states)[1]
encoder_attention_mask = tf.fill((batch_size, encoder_seq_length), 1)
# initialize past_key_value_states with `None` if past does not exist
if past_key_value_states is None:
past_key_value_states = [None] * len(self.block)
# initialize past_key_values with `None` if past does not exist
if past_key_values is None:
past_key_values = [None] * len(self.block)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
@@ -677,7 +691,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
)
causal_mask = tf.cast(causal_mask, dtype=tf.float32)
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
if past_key_value_states[0] is not None:
if past_key_values[0] is not None:
extended_attention_mask = extended_attention_mask[:, :, -1:, :]
else:
extended_attention_mask = attention_mask[:, None, None, :]
@@ -726,7 +740,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
hidden_states = self.dropout(inputs_embeds, training=training)
for i, (layer_module, past_key_value_state) in enumerate(zip(self.block, past_key_value_states)):
for i, (layer_module, past_key_value_state) in enumerate(zip(self.block, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@@ -878,7 +892,7 @@ T5_INPUTS_DOCSTRING = r"""
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
decoder_input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation.
If `decoder_past_key_value_states` is used, optionally only the last `decoder_input_ids` have to be input (see `decoder_past_key_value_states`).
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
@@ -889,13 +903,13 @@ T5_INPUTS_DOCSTRING = r"""
Used in the cross-attention of the decoder.
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
decoder_past_key_value_states (:obj:`tuple(tuple(tf.Tensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
past_key_values (:obj:`tuple(tuple(tf.Tensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains pre-computed key and value hidden-states of the attention blocks.
Can be used to speed up decoding.
If `decoder_past_key_value_states` are used, the user can optionally input only the last `decoder_input_ids`
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
If `use_cache` is True, `decoder_past_key_value_states` are returned and can be used to speed up decoding (see `decoder_past_key_value_states`).
If `use_cache` is True, `past_key_values` are returned and can be used to speed up decoding (see `past_key_values`).
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
Optionally, instead of passing :obj:`inputs` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `inputs` indices into associated vectors
@@ -969,7 +983,7 @@ class TFT5Model(TFT5PreTrainedModel):
encoder_outputs=None,
inputs_embeds=None,
head_mask=None,
decoder_past_key_value_states=None,
past_key_values=None,
decoder_input_ids=None,
decoder_attention_mask=None,
decoder_inputs_embeds=None,
@@ -978,6 +992,7 @@ class TFT5Model(TFT5PreTrainedModel):
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
r"""
Returns:
@@ -999,7 +1014,7 @@ class TFT5Model(TFT5PreTrainedModel):
encoder_outputs = inputs[2] if len(inputs) > 2 else encoder_outputs
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
head_mask = inputs[4] if len(inputs) > 4 else head_mask
decoder_past_key_value_states = inputs[5] if len(inputs) > 5 else decoder_past_key_value_states
past_key_values = inputs[5] if len(inputs) > 5 else past_key_values
decoder_input_ids = inputs[6] if len(inputs) > 6 else decoder_input_ids
decoder_attention_mask = inputs[7] if len(inputs) > 7 else decoder_attention_mask
decoder_inputs_embeds = inputs[8] if len(inputs) > 8 else decoder_inputs_embeds
@@ -1017,7 +1032,7 @@ class TFT5Model(TFT5PreTrainedModel):
encoder_outputs = inputs.get("encoder_outputs", encoder_outputs)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
head_mask = inputs.get("head_mask", head_mask)
decoder_past_key_value_states = inputs.get("past_key_value_states", decoder_past_key_value_states)
past_key_values = inputs.get("past_key_values", past_key_values)
decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids)
decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask)
decoder_inputs_embeds = inputs.get("decoder_inputs_embeds", decoder_inputs_embeds)
@@ -1026,9 +1041,23 @@ class TFT5Model(TFT5PreTrainedModel):
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 13, "Too many inputs."
if "past_key_value_states" in inputs:
warnings.warn(
"The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
past_key_values = inputs.pop("past_key_value_states")
else:
input_ids = inputs
if "past_key_value_states" in kwargs:
warnings.warn(
"The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
past_key_values = kwargs.pop("past_key_value_states")
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.return_dict
@@ -1054,7 +1083,7 @@ class TFT5Model(TFT5PreTrainedModel):
# If decoding with past key value states, only the last tokens
# should be given as an input
if decoder_past_key_value_states is not None:
if past_key_values is not None:
if decoder_input_ids is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
if decoder_inputs_embeds is not None:
@@ -1069,7 +1098,7 @@ class TFT5Model(TFT5PreTrainedModel):
attention_mask,
decoder_inputs_embeds,
head_mask,
decoder_past_key_value_states,
past_key_values,
use_cache,
output_attentions,
output_hidden_states,
@@ -1103,7 +1132,7 @@ class TFT5Model(TFT5PreTrainedModel):
return TFSeq2SeqModelOutput(
last_hidden_state=decoder_outputs[0],
decoder_past_key_values=past,
past_key_values=past,
decoder_hidden_states=decoder_outputs[2],
decoder_attentions=decoder_outputs[3],
encoder_last_hidden_state=encoder_outputs[0],
@@ -1164,7 +1193,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
encoder_outputs=None,
inputs_embeds=None,
head_mask=None,
decoder_past_key_value_states=None,
past_key_values=None,
decoder_input_ids=None,
decoder_attention_mask=None,
decoder_inputs_embeds=None,
@@ -1174,6 +1203,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@@ -1204,7 +1234,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
encoder_outputs = inputs[2] if len(inputs) > 2 else encoder_outputs
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
head_mask = inputs[4] if len(inputs) > 4 else head_mask
decoder_past_key_value_states = inputs[5] if len(inputs) > 5 else decoder_past_key_value_states
past_key_values = inputs[5] if len(inputs) > 5 else past_key_values
decoder_input_ids = inputs[6] if len(inputs) > 6 else decoder_input_ids
decoder_attention_mask = inputs[7] if len(inputs) > 7 else decoder_attention_mask
decoder_inputs_embeds = inputs[8] if len(inputs) > 8 else decoder_inputs_embeds
@@ -1223,7 +1253,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
encoder_outputs = inputs.get("encoder_outputs", encoder_outputs)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
head_mask = inputs.get("head_mask", head_mask)
decoder_past_key_value_states = inputs.get("past_key_value_states", decoder_past_key_value_states)
past_key_values = inputs.get("past_key_values", past_key_values)
decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids)
decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask)
decoder_inputs_embeds = inputs.get("decoder_inputs_embeds", decoder_inputs_embeds)
@@ -1233,9 +1263,23 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
return_dict = inputs.get("return_dict", return_dict)
labels = inputs.get("labels", labels)
assert len(inputs) <= 14, "Too many inputs."
if "past_key_value_states" in inputs:
warnings.warn(
"The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
past_key_values = inputs.pop("past_key_value_states")
else:
input_ids = inputs
if "past_key_value_states" in kwargs:
warnings.warn(
"The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
past_key_values = kwargs.pop("past_key_value_states")
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.return_dict
@@ -1266,7 +1310,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
# If decoding with past key value states, only the last tokens
# should be given as an input
if decoder_past_key_value_states is not None:
if past_key_values is not None:
if decoder_input_ids is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
if decoder_inputs_embeds is not None:
@@ -1281,7 +1325,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
attention_mask,
decoder_inputs_embeds,
head_mask,
decoder_past_key_value_states,
past_key_values,
use_cache,
output_attentions,
output_hidden_states,
@@ -1324,7 +1368,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
return TFSeq2SeqLMOutput(
loss=loss,
logits=logits,
decoder_past_key_values=past,
past_key_values=past,
decoder_hidden_states=decoder_outputs[2],
decoder_attentions=decoder_outputs[3],
encoder_last_hidden_state=encoder_outputs[0],
@@ -1337,14 +1381,14 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
# first step
if len(past) < 2:
encoder_outputs, decoder_past_key_value_states = past, None
encoder_outputs, past_key_values = past, None
else:
encoder_outputs, decoder_past_key_value_states = past[0], past[1]
encoder_outputs, past_key_values = past[0], past[1]
return {
"inputs": None, # inputs don't have to be defined, but still need to be passed to make Keras.layer.__call__ happy
"decoder_input_ids": inputs, # inputs are the decoder_input_ids
"decoder_past_key_value_states": decoder_past_key_value_states,
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"use_cache": use_cache,

View File

@@ -661,6 +661,15 @@ class TransfoXLLMHeadModelOutput(ModelOutput):
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@property
def logits(self):
# Read-only alias that exposes `self.prediction_scores` under the name `logits`.
# Prediction scores are the output of the adaptive softmax, see
# the file `modeling_transfo_xl_utilities`. Since the adaptive
# softmax returns the log softmax value, `self.prediction_scores`
# are strictly speaking not exactly `logits`, but behave the same
# way logits do.
return self.prediction_scores
TRANSFO_XL_START_DOCSTRING = r"""