[Generate] Facilitate PyTorch generate using ModelOutputs (#6735)
* fix generate for GPT2 Double Head * fix gpt2 double head model * fix bart / t5 * also add for no beam search * fix no beam search * fix encoder decoder * simplify t5 * simplify t5 * fix t5 tests * fix BART * fix transfo-xl * fix conflict * integrating sylvains and sams comments * fix tf past_decoder_key_values * fix enc dec test
This commit is contained in:
committed by
GitHub
parent
397f819615
commit
afc4ece462
@@ -20,6 +20,7 @@ import torch
|
||||
from torch import Tensor
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .file_utils import ModelOutput
|
||||
from .utils import logging
|
||||
|
||||
|
||||
@@ -46,14 +47,6 @@ class GenerationMixin:
|
||||
"""
|
||||
return logits
|
||||
|
||||
def _use_cache(self, outputs, use_cache):
|
||||
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
|
||||
if len(outputs) <= 1 or use_cache is False:
|
||||
return False
|
||||
if hasattr(self.config, "mem_len") and self.config.mem_len == 0:
|
||||
return False
|
||||
return True
|
||||
|
||||
def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
|
||||
"""
|
||||
Enforce the repetition penalty (from the `CTRL paper <https://arxiv.org/abs/1909.05858>`__).
|
||||
@@ -137,7 +130,7 @@ class GenerationMixin:
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
decoder_start_token_id: Optional[int] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
**model_specific_kwargs
|
||||
**model_kwargs
|
||||
) -> torch.LongTensor:
|
||||
r"""
|
||||
Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
|
||||
@@ -208,7 +201,7 @@ class GenerationMixin:
|
||||
use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
|
||||
speed up decoding.
|
||||
model_specific_kwargs:
|
||||
model_kwargs:
|
||||
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
|
||||
|
||||
Return:
|
||||
@@ -400,7 +393,7 @@ class GenerationMixin:
|
||||
|
||||
# get encoder and store encoder outputs
|
||||
encoder = self.get_encoder()
|
||||
encoder_outputs: tuple = encoder(input_ids, attention_mask=attention_mask)
|
||||
encoder_outputs: ModelOutput = encoder(input_ids, attention_mask=attention_mask, return_dict=True)
|
||||
|
||||
# Expand input ids if num_beams > 1 or num_return_sequences > 1
|
||||
if num_return_sequences > 1 or num_beams > 1:
|
||||
@@ -428,8 +421,8 @@ class GenerationMixin:
|
||||
cur_len = 1
|
||||
|
||||
assert (
|
||||
batch_size == encoder_outputs[0].shape[0]
|
||||
), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} "
|
||||
batch_size == encoder_outputs.last_hidden_state.shape[0]
|
||||
), f"expected encoder_outputs.last_hidden_state to have 1st dimension bs={batch_size}, got {encoder_outputs.last_hidden_state.shape[0]} "
|
||||
|
||||
# expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
|
||||
expanded_batch_idxs = (
|
||||
@@ -439,11 +432,16 @@ class GenerationMixin:
|
||||
.view(-1)
|
||||
.to(input_ids.device)
|
||||
)
|
||||
|
||||
# expand encoder_outputs
|
||||
encoder_outputs = (encoder_outputs[0].index_select(0, expanded_batch_idxs), *encoder_outputs[1:])
|
||||
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
|
||||
0, expanded_batch_idxs
|
||||
)
|
||||
|
||||
# save encoder_outputs in `model_kwargs`
|
||||
model_kwargs["encoder_outputs"] = encoder_outputs
|
||||
|
||||
else:
|
||||
encoder_outputs = None
|
||||
cur_len = input_ids.shape[-1]
|
||||
|
||||
assert (
|
||||
@@ -471,10 +469,9 @@ class GenerationMixin:
|
||||
length_penalty=length_penalty,
|
||||
num_beams=num_beams,
|
||||
vocab_size=vocab_size,
|
||||
encoder_outputs=encoder_outputs,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=use_cache,
|
||||
model_specific_kwargs=model_specific_kwargs,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
else:
|
||||
output = self._generate_no_beam_search(
|
||||
@@ -492,10 +489,9 @@ class GenerationMixin:
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
batch_size=effective_batch_size,
|
||||
encoder_outputs=encoder_outputs,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=use_cache,
|
||||
model_specific_kwargs=model_specific_kwargs,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
|
||||
return output
|
||||
@@ -516,10 +512,9 @@ class GenerationMixin:
|
||||
pad_token_id,
|
||||
eos_token_id,
|
||||
batch_size,
|
||||
encoder_outputs,
|
||||
attention_mask,
|
||||
use_cache,
|
||||
model_specific_kwargs,
|
||||
model_kwargs,
|
||||
):
|
||||
"""Generate sequences for each example without beam search (num_beams == 1).
|
||||
All returned sequence are generated independantly.
|
||||
@@ -528,15 +523,14 @@ class GenerationMixin:
|
||||
unfinished_sents = input_ids.new(batch_size).fill_(1)
|
||||
sent_lengths = input_ids.new(batch_size).fill_(max_length)
|
||||
|
||||
past = (encoder_outputs, None) if encoder_outputs is not None else None
|
||||
|
||||
past = None
|
||||
while cur_len < max_length:
|
||||
model_inputs = self.prepare_inputs_for_generation(
|
||||
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
|
||||
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
|
||||
)
|
||||
|
||||
outputs = self(**model_inputs)
|
||||
next_token_logits = outputs[0][:, -1, :]
|
||||
outputs = self(**model_inputs, return_dict=True)
|
||||
next_token_logits = outputs.logits[:, -1, :]
|
||||
|
||||
scores = self.postprocess_next_token_scores(
|
||||
scores=next_token_logits,
|
||||
@@ -553,8 +547,10 @@ class GenerationMixin:
|
||||
)
|
||||
|
||||
# if model has past, then set the past variable to speed up decoding
|
||||
if self._use_cache(outputs, use_cache):
|
||||
past = outputs[1]
|
||||
if "past_key_values" in outputs:
|
||||
past = outputs.past_key_values
|
||||
elif "mems" in outputs:
|
||||
past = outputs.mems
|
||||
|
||||
if do_sample:
|
||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||
@@ -621,10 +617,9 @@ class GenerationMixin:
|
||||
length_penalty,
|
||||
num_beams,
|
||||
vocab_size,
|
||||
encoder_outputs,
|
||||
attention_mask,
|
||||
use_cache,
|
||||
model_specific_kwargs,
|
||||
model_kwargs,
|
||||
):
|
||||
"""Generate sequences for each example with beam search."""
|
||||
|
||||
@@ -643,21 +638,24 @@ class GenerationMixin:
|
||||
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
|
||||
|
||||
# cache compute states
|
||||
past = (encoder_outputs, None) if encoder_outputs is not None else None
|
||||
past = None
|
||||
|
||||
# done sentences
|
||||
done = [False for _ in range(batch_size)]
|
||||
|
||||
while cur_len < max_length:
|
||||
model_inputs = self.prepare_inputs_for_generation(
|
||||
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
|
||||
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
|
||||
)
|
||||
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
|
||||
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
|
||||
outputs = self(**model_inputs, return_dict=True) # (batch_size * num_beams, cur_len, vocab_size)
|
||||
next_token_logits = outputs.logits[:, -1, :] # (batch_size * num_beams, vocab_size)
|
||||
|
||||
# if model has past, then set the past variable to speed up decoding
|
||||
if self._use_cache(outputs, use_cache):
|
||||
past = outputs[1]
|
||||
if "past_key_values" in outputs:
|
||||
past = outputs.past_key_values
|
||||
elif "mems" in outputs:
|
||||
past = outputs.mems
|
||||
|
||||
if self.config.is_encoder_decoder and do_sample is False:
|
||||
# TODO (PVP) still a bit hacky here - there might be a better solution
|
||||
next_token_logits = self.adjust_logits_during_generation(
|
||||
|
||||
@@ -111,15 +111,15 @@ BART_INPUTS_DOCSTRING = r"""
|
||||
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
|
||||
If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify.
|
||||
See diagram 1 in the paper for more info on the default strategy
|
||||
decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains pre-computed key and value hidden-states of the attention blocks.
|
||||
Can be used to speed up decoding.
|
||||
If ``decoder_past_key_value_states`` are used, the user can optionally input only the last
|
||||
If ``past_key_values`` are used, the user can optionally input only the last
|
||||
``decoder_input_ids`` (those that don't have their past key value states given to this model) of shape
|
||||
:obj:`(batch_size, 1)` instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`.
|
||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
If `use_cache` is True, ``decoder_past_key_values`` are returned and can be used to speed up decoding (see
|
||||
``decoder_past_key_values``).
|
||||
If `use_cache` is True, ``past_key_values`` are returned and can be used to speed up decoding (see
|
||||
``past_key_values``).
|
||||
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
||||
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
|
||||
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
||||
@@ -502,7 +502,7 @@ class BartDecoder(nn.Module):
|
||||
encoder_padding_mask,
|
||||
decoder_padding_mask,
|
||||
decoder_causal_mask,
|
||||
decoder_past_key_values=None,
|
||||
past_key_values=None,
|
||||
use_cache=False,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
@@ -519,7 +519,7 @@ class BartDecoder(nn.Module):
|
||||
encoder_hidden_states: output from the encoder, used for
|
||||
encoder-side attention
|
||||
encoder_padding_mask: for ignoring pad tokens
|
||||
decoder_past_key_values (dict or None): dictionary used for storing state during generation
|
||||
past_key_values (dict or None): dictionary used for storing state during generation
|
||||
|
||||
Returns:
|
||||
BaseModelOutputWithPast or tuple:
|
||||
@@ -530,10 +530,16 @@ class BartDecoder(nn.Module):
|
||||
"""
|
||||
if "decoder_cached_states" in unused:
|
||||
warnings.warn(
|
||||
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.",
|
||||
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
decoder_past_key_values = unused.pop("decoder_cached_states")
|
||||
past_key_values = unused.pop("decoder_cached_states")
|
||||
if "decoder_past_key_values" in unused:
|
||||
warnings.warn(
|
||||
"The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
past_key_values = unused.pop("decoder_past_key_values")
|
||||
|
||||
# check attention mask and invert
|
||||
if encoder_padding_mask is not None:
|
||||
@@ -568,7 +574,7 @@ class BartDecoder(nn.Module):
|
||||
if self.training and (dropout_probability < self.layerdrop):
|
||||
continue
|
||||
|
||||
layer_state = decoder_past_key_values[idx] if decoder_past_key_values is not None else None
|
||||
layer_state = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
x, layer_self_attn, layer_past = decoder_layer(
|
||||
x,
|
||||
@@ -594,10 +600,7 @@ class BartDecoder(nn.Module):
|
||||
x = x.transpose(0, 1)
|
||||
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
|
||||
|
||||
if use_cache:
|
||||
next_cache = ((encoder_hidden_states, encoder_padding_mask), next_decoder_cache)
|
||||
else:
|
||||
next_cache = None
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [x, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
@@ -869,13 +872,19 @@ class BartModel(PretrainedBartModel):
|
||||
decoder_input_ids=None,
|
||||
encoder_outputs: Optional[Tuple] = None,
|
||||
decoder_attention_mask=None,
|
||||
decoder_past_key_values=None,
|
||||
past_key_values=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
**kwargs,
|
||||
):
|
||||
if "decoder_past_key_values" in kwargs:
|
||||
warnings.warn(
|
||||
"The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
past_key_values = kwargs.pop("decoder_past_key_values")
|
||||
|
||||
if decoder_input_ids is None:
|
||||
use_cache = False
|
||||
@@ -924,7 +933,7 @@ class BartModel(PretrainedBartModel):
|
||||
attention_mask,
|
||||
decoder_padding_mask,
|
||||
decoder_causal_mask=causal_mask,
|
||||
decoder_past_key_values=decoder_past_key_values,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
@@ -936,7 +945,7 @@ class BartModel(PretrainedBartModel):
|
||||
|
||||
return Seq2SeqModelOutput(
|
||||
last_hidden_state=decoder_outputs.last_hidden_state,
|
||||
decoder_past_key_values=decoder_outputs.past_key_values,
|
||||
past_key_values=decoder_outputs.past_key_values,
|
||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||
decoder_attentions=decoder_outputs.attentions,
|
||||
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
||||
@@ -994,7 +1003,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
encoder_outputs=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
decoder_past_key_values=None,
|
||||
past_key_values=None,
|
||||
labels=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
@@ -1037,10 +1046,16 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
labels = unused.pop("lm_labels")
|
||||
if "decoder_cached_states" in unused:
|
||||
warnings.warn(
|
||||
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.",
|
||||
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
decoder_past_key_values = unused.pop("decoder_cached_states")
|
||||
past_key_values = unused.pop("decoder_cached_states")
|
||||
if "decoder_past_key_values" in unused:
|
||||
warnings.warn(
|
||||
"The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
past_key_values = unused.pop("decoder_past_key_values")
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if labels is not None:
|
||||
@@ -1054,7 +1069,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
encoder_outputs=encoder_outputs,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
decoder_past_key_values=decoder_past_key_values,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
@@ -1075,7 +1090,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
return Seq2SeqLMOutput(
|
||||
loss=masked_lm_loss,
|
||||
logits=lm_logits,
|
||||
decoder_past_key_values=outputs.decoder_past_key_values,
|
||||
past_key_values=outputs.past_key_values,
|
||||
decoder_hidden_states=outputs.decoder_hidden_states,
|
||||
decoder_attentions=outputs.decoder_attentions,
|
||||
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
||||
@@ -1083,14 +1098,13 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
encoder_attentions=outputs.encoder_attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs):
|
||||
assert past is not None, "past has to be defined for encoder_outputs"
|
||||
|
||||
encoder_outputs, decoder_past_key_values = past
|
||||
def prepare_inputs_for_generation(
|
||||
self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs
|
||||
):
|
||||
return {
|
||||
"input_ids": None, # encoder_outputs is defined. input_ids not needed
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"decoder_past_key_values": decoder_past_key_values,
|
||||
"past_key_values": past,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||
@@ -1109,20 +1123,14 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
((enc_out, enc_mask), decoder_past_key_values) = past
|
||||
reordered_past = []
|
||||
for layer_past in decoder_past_key_values:
|
||||
for layer_past in past:
|
||||
# get the correct batch idx from decoder layer's batch dim for cross and self-attn
|
||||
layer_past_new = {
|
||||
attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
|
||||
}
|
||||
reordered_past.append(layer_past_new)
|
||||
|
||||
new_enc_out = enc_out if enc_out is None else enc_out.index_select(0, beam_idx)
|
||||
new_enc_mask = enc_mask if enc_mask is None else enc_mask.index_select(0, beam_idx)
|
||||
|
||||
past = ((new_enc_out, new_enc_mask), reordered_past)
|
||||
return past
|
||||
return reordered_past
|
||||
|
||||
def get_encoder(self):
|
||||
return self.model.encoder
|
||||
@@ -1208,7 +1216,7 @@ class BartForSequenceClassification(PretrainedBartModel):
|
||||
return Seq2SeqSequenceClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
decoder_past_key_values=outputs.decoder_past_key_values,
|
||||
past_key_values=outputs.past_key_values,
|
||||
decoder_hidden_states=outputs.decoder_hidden_states,
|
||||
decoder_attentions=outputs.decoder_attentions,
|
||||
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
||||
@@ -1316,7 +1324,7 @@ class BartForQuestionAnswering(PretrainedBartModel):
|
||||
loss=total_loss,
|
||||
start_logits=start_logits,
|
||||
end_logits=end_logits,
|
||||
decoder_past_key_values=outputs.decoder_past_key_values,
|
||||
past_key_values=outputs.past_key_values,
|
||||
decoder_hidden_states=outputs.decoder_hidden_states,
|
||||
decoder_attentions=outputs.decoder_attentions,
|
||||
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
||||
|
||||
@@ -19,13 +19,79 @@ from typing import Optional
|
||||
|
||||
from .configuration_encoder_decoder import EncoderDecoderConfig
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable, replace_return_docstrings
|
||||
from .modeling_outputs import Seq2SeqLMOutput
|
||||
from .modeling_utils import PreTrainedModel
|
||||
from .utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "EncoderDecoderConfig"
|
||||
|
||||
ENCODER_DECODER_START_DOCSTRING = r"""
|
||||
This class can be used to inialize a sequence-to-sequnece model with any pretrained autoencoding model as the encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via :meth:`~transformers.AutoModel.from_pretrained` function and the decoder is loaded via :meth:`~transformers.AutoModelForCausalLM.from_pretrained` function.
|
||||
Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream generative task, *i.e.* summarization.
|
||||
|
||||
The effectiveness of initializing sequence-to-sequence models with pre-trained checkpoints for sequence generation tasks was shown in `Leveraging Pre-trained Checkpoints for Sequence Generation Tasks <https://arxiv.org/abs/1907.12461>`__ by Sascha Rothe, Shashi Narayan, Aliaksei Severyn.
|
||||
Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu.
|
||||
|
||||
After such an Encoder Decoder model has been trained / fine-tuned, it can be saved / loaded just like any other models (see Examples for more information).
|
||||
|
||||
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#module>`__ sub-class. Use it as a
|
||||
regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.
|
||||
|
||||
Parameters:
|
||||
config (:class:`~transformers.T5Config`): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the configuration.
|
||||
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
|
||||
"""
|
||||
|
||||
ENCODER_DECODER_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary for the encoder.
|
||||
Indices can be obtained using :class:`~transformers.PretrainedTokenizer`.
|
||||
See :meth:`~transformers.PreTrainedTokenizer.encode` and
|
||||
:meth:`~transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Mask to avoid performing attention on padding token indices for the encoder.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
encoder_outputs (:obj:`tuple(torch.FloatTensor)`, `optional`, defaults to :obj:`None`):
|
||||
This tuple must consist of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`)
|
||||
`last_hidden_state` (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`) is a tensor of hidden-states at the output of the last layer of the encoder.
|
||||
Used in the cross-attention of the decoder.
|
||||
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Provide for sequence to sequence training to the decoder.
|
||||
Indices can be obtained using :class:`transformers.PretrainedTokenizer`.
|
||||
See :func:`transformers.PreTrainedTokenizer.encode` and
|
||||
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
|
||||
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
|
||||
decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
||||
Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Labels for computing the masked language modeling loss for the decoder.
|
||||
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
||||
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
|
||||
in ``[0, ..., config.vocab_size]``
|
||||
return_dict (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
||||
If set to ``True``, the model will return a :class:`~transformers.file_utils.Seq2SeqLMOutput` instead of a
|
||||
plain tuple.
|
||||
kwargs: (`optional`) Remaining dictionary of keyword arguments. Keyword arguments come in two flavors:
|
||||
- Without a prefix which will be input as ``**encoder_kwargs`` for the encoder forward function.
|
||||
- With a `decoder_` prefix which will be input as ``**decoder_kwargs`` for the decoder forward function.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING)
|
||||
class EncoderDecoderModel(PreTrainedModel):
|
||||
r"""
|
||||
:class:`~transformers.EncoderDecoder` is a generic model class that will be
|
||||
@@ -206,6 +272,8 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
|
||||
return cls(encoder=encoder, decoder=decoder, config=config)
|
||||
|
||||
@add_start_docstrings_to_callable(ENCODER_DECODER_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
@@ -216,47 +284,11 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||
decoder_attention_mask=None,
|
||||
decoder_inputs_embeds=None,
|
||||
labels=None,
|
||||
return_dict=None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
"""
|
||||
Args:
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary for the encoder.
|
||||
Indices can be obtained using :class:`transformers.PretrainedTokenizer`.
|
||||
See :func:`transformers.PreTrainedTokenizer.encode` and
|
||||
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Mask to avoid performing attention on padding token indices for the encoder.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`, defaults to :obj:`None`):
|
||||
Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`)
|
||||
`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`) is a sequence of hidden-states at the output of the last layer of the encoder.
|
||||
Used in the cross-attention of the decoder.
|
||||
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Provide for sequence to sequence training to the decoder.
|
||||
Indices can be obtained using :class:`transformers.PretrainedTokenizer`.
|
||||
See :func:`transformers.PreTrainedTokenizer.encode` and
|
||||
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
|
||||
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
|
||||
decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
||||
Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Labels for computing the masked language modeling loss for the decoder.
|
||||
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
||||
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
|
||||
in ``[0, ..., config.vocab_size]``
|
||||
kwargs: (`optional`) Remaining dictionary of keyword arguments. Keyword arguments come in two flavors:
|
||||
- Without a prefix which will be input as `**encoder_kwargs` for the encoder forward function.
|
||||
- With a `decoder_` prefix which will be input as `**decoder_kwargs` for the decoder forward function.
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Examples::
|
||||
|
||||
@@ -264,19 +296,25 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||
>>> import torch
|
||||
|
||||
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
|
||||
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert from pre-trained checkpoints
|
||||
|
||||
>>> # forward
|
||||
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
||||
>>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)
|
||||
|
||||
>>> # training
|
||||
>>> loss, outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=input_ids)[:2]
|
||||
>>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=input_ids, return_dict=True)
|
||||
>>> loss, logits = outputs.loss, outputs.logits
|
||||
|
||||
>>> # save and load from pretrained
|
||||
>>> model.save_pretrained("bert2bert")
|
||||
>>> model = EncoderDecoderModel.from_pretrained("bert2bert")
|
||||
|
||||
>>> # generation
|
||||
>>> generated = model.generate(input_ids, decoder_start_token_id=model.config.decoder.pad_token_id)
|
||||
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
|
||||
|
||||
@@ -289,7 +327,7 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
return_dict=False,
|
||||
return_dict=return_dict,
|
||||
**kwargs_encoder,
|
||||
)
|
||||
|
||||
@@ -303,23 +341,28 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||
encoder_hidden_states=hidden_states,
|
||||
encoder_attention_mask=attention_mask,
|
||||
labels=labels,
|
||||
return_dict=False,
|
||||
return_dict=return_dict,
|
||||
**kwargs_decoder,
|
||||
)
|
||||
|
||||
# TODO(PVP): currently it is not possible to use `past`
|
||||
# with the encoder/decoder framework -> should be implemented
|
||||
if not return_dict:
|
||||
return decoder_outputs + encoder_outputs
|
||||
|
||||
return Seq2SeqLMOutput(
|
||||
loss=decoder_outputs.loss,
|
||||
logits=decoder_outputs.logits,
|
||||
past_key_values=None, # TODO(PVP) - need to implement cache for BERT, etc... before this works
|
||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||
decoder_attentions=decoder_outputs.attentions,
|
||||
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
||||
encoder_hidden_states=encoder_outputs.hidden_states,
|
||||
encoder_attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
return decoder_outputs + encoder_outputs
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwargs):
|
||||
assert past is not None, "past has to be defined for encoder_outputs"
|
||||
|
||||
# first step
|
||||
if type(past) is tuple:
|
||||
encoder_outputs, _ = past
|
||||
else:
|
||||
encoder_outputs = (past,)
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, encoder_outputs, **kwargs):
|
||||
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids)
|
||||
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
|
||||
input_dict = {
|
||||
@@ -335,7 +378,7 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||
input_dict["decoder_use_cache"] = decoder_inputs["use_cache"]
|
||||
|
||||
if "past_key_values" in decoder_inputs:
|
||||
input_dict["decoder_past_key_values"] = decoder_inputs["past_key_values"]
|
||||
input_dict["past_key_values"] = decoder_inputs["past_key_values"]
|
||||
|
||||
return input_dict
|
||||
|
||||
|
||||
@@ -353,11 +353,11 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
|
||||
Base class for outputs of models predicting if two sentences are consecutive or not.
|
||||
|
||||
Args:
|
||||
lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
|
||||
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
|
||||
Language modeling loss.
|
||||
mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided):
|
||||
Multiple choice classification loss.
|
||||
lm_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
|
||||
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
|
||||
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
|
||||
@@ -380,9 +380,9 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
|
||||
heads.
|
||||
"""
|
||||
|
||||
lm_loss: Optional[torch.FloatTensor] = None
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
mc_loss: Optional[torch.FloatTensor] = None
|
||||
lm_logits: torch.FloatTensor = None
|
||||
logits: torch.FloatTensor = None
|
||||
mc_logits: torch.FloatTensor = None
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
@@ -777,6 +777,17 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
||||
# only last token for inputs_ids if past is defined in kwargs
|
||||
if past:
|
||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"past_key_values": past,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
}
|
||||
|
||||
@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
@@ -893,9 +904,9 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
||||
return ((lm_loss,) + output) if lm_loss is not None else output
|
||||
|
||||
return GPT2DoubleHeadsModelOutput(
|
||||
lm_loss=lm_loss,
|
||||
loss=lm_loss,
|
||||
mc_loss=mc_loss,
|
||||
lm_logits=lm_logits,
|
||||
logits=lm_logits,
|
||||
mc_logits=mc_logits,
|
||||
past_key_values=transformer_outputs.past_key_values,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
|
||||
@@ -300,11 +300,11 @@ class OpenAIGPTDoubleHeadsModelOutput(ModelOutput):
|
||||
Base class for outputs of models predicting if two sentences are consecutive or not.
|
||||
|
||||
Args:
|
||||
lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
|
||||
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
|
||||
Language modeling loss.
|
||||
mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided):
|
||||
Multiple choice classification loss.
|
||||
lm_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
|
||||
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
|
||||
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
|
||||
@@ -321,9 +321,9 @@ class OpenAIGPTDoubleHeadsModelOutput(ModelOutput):
|
||||
heads.
|
||||
"""
|
||||
|
||||
lm_loss: Optional[torch.FloatTensor] = None
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
mc_loss: Optional[torch.FloatTensor] = None
|
||||
lm_logits: torch.FloatTensor = None
|
||||
logits: torch.FloatTensor = None
|
||||
mc_logits: torch.FloatTensor = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
@@ -713,9 +713,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
||||
return ((lm_loss,) + output) if lm_loss is not None else output
|
||||
|
||||
return OpenAIGPTDoubleHeadsModelOutput(
|
||||
lm_loss=lm_loss,
|
||||
loss=lm_loss,
|
||||
mc_loss=mc_loss,
|
||||
lm_logits=lm_logits,
|
||||
logits=lm_logits,
|
||||
mc_logits=mc_logits,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
|
||||
@@ -109,13 +109,13 @@ class Seq2SeqModelOutput(ModelOutput):
|
||||
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the decoder of the model.
|
||||
|
||||
If ``decoder_past_key_values`` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
|
||||
decoder_past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||
If ``past_key_values`` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
|
||||
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
|
||||
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
|
||||
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
|
||||
used (see ``past_key_values`` input) to speed up sequential decoding.
|
||||
decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
@@ -143,7 +143,7 @@ class Seq2SeqModelOutput(ModelOutput):
|
||||
"""
|
||||
|
||||
last_hidden_state: torch.FloatTensor
|
||||
decoder_past_key_values: Optional[List[torch.FloatTensor]] = None
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None
|
||||
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
|
||||
@@ -255,12 +255,12 @@ class Seq2SeqLMOutput(ModelOutput):
|
||||
Languaged modeling loss.
|
||||
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
decoder_past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
|
||||
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
|
||||
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
|
||||
used (see ``past_key_values`` input) to speed up sequential decoding.
|
||||
decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
@@ -289,7 +289,7 @@ class Seq2SeqLMOutput(ModelOutput):
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits: torch.FloatTensor = None
|
||||
decoder_past_key_values: Optional[List[torch.FloatTensor]] = None
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None
|
||||
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
|
||||
@@ -365,12 +365,12 @@ class Seq2SeqSequenceClassifierOutput(ModelOutput):
|
||||
Classification (or regression if config.num_labels==1) loss.
|
||||
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
|
||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||
decoder_past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
|
||||
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
|
||||
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
|
||||
used (see ``past_key_values`` input) to speed up sequential decoding.
|
||||
decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
@@ -399,7 +399,7 @@ class Seq2SeqSequenceClassifierOutput(ModelOutput):
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits: torch.FloatTensor = None
|
||||
decoder_past_key_values: Optional[List[torch.FloatTensor]] = None
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None
|
||||
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
|
||||
@@ -511,12 +511,12 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput):
|
||||
Span-start scores (before SoftMax).
|
||||
end_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
|
||||
Span-end scores (before SoftMax).
|
||||
decoder_past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
|
||||
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
|
||||
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
|
||||
used (see ``past_key_values`` input) to speed up sequential decoding.
|
||||
decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
@@ -546,7 +546,7 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput):
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
start_logits: torch.FloatTensor = None
|
||||
end_logits: torch.FloatTensor = None
|
||||
decoder_past_key_values: Optional[List[torch.FloatTensor]] = None
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None
|
||||
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
|
||||
|
||||
@@ -838,27 +838,27 @@ T5_INPUTS_DOCSTRING = r"""
|
||||
Used in the cross-attention of the decoder.
|
||||
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation.
|
||||
If `decoder_past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `decoder_past_key_values`).
|
||||
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
|
||||
To know more on how to prepare :obj:`decoder_input_ids` for pre-training take a look at
|
||||
`T5 Training <./t5.html#training>`__. If decoder_input_ids and decoder_inputs_embeds are both None,
|
||||
decoder_input_ids takes the value of input_ids.
|
||||
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
|
||||
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
|
||||
decoder_past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains pre-computed key and value hidden-states of the attention blocks.
|
||||
Can be used to speed up decoding.
|
||||
If `decoder_past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
|
||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
|
||||
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||
instead of all `decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
If `use_cache` is True, `decoder_past_key_values` are returned and can be used to speed up decoding (see `decoder_past_key_values`).
|
||||
If `use_cache` is True, `past_key_values` are returned and can be used to speed up decoding (see `past_key_values`).
|
||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
||||
Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation.
|
||||
If `decoder_past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be input (see `decoder_past_key_values`).
|
||||
If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be input (see `past_key_values`).
|
||||
This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix. If decoder_input_ids and decoder_inputs_embeds are both None,
|
||||
decoder_inputs_embeds takes the value of inputs_embeds.
|
||||
@@ -928,7 +928,7 @@ class T5Model(T5PreTrainedModel):
|
||||
encoder_outputs=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
decoder_past_key_values=None,
|
||||
past_key_values=None,
|
||||
use_cache=None,
|
||||
inputs_embeds=None,
|
||||
decoder_inputs_embeds=None,
|
||||
@@ -955,10 +955,16 @@ class T5Model(T5PreTrainedModel):
|
||||
"""
|
||||
if "decoder_past_key_value_states" in kwargs:
|
||||
warnings.warn(
|
||||
"The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.",
|
||||
"The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
decoder_past_key_values = kwargs.pop("decoder_past_key_value_states")
|
||||
past_key_values = kwargs.pop("decoder_past_key_value_states")
|
||||
if "decoder_past_key_values" in kwargs:
|
||||
warnings.warn(
|
||||
"The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
past_key_values = kwargs.pop("decoder_past_key_values")
|
||||
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
@@ -992,7 +998,7 @@ class T5Model(T5PreTrainedModel):
|
||||
|
||||
# If decoding with past key value states, only the last tokens
|
||||
# should be given as an input
|
||||
if decoder_past_key_values is not None:
|
||||
if past_key_values is not None:
|
||||
if decoder_input_ids is not None:
|
||||
decoder_input_ids = decoder_input_ids[:, -1:]
|
||||
if decoder_inputs_embeds is not None:
|
||||
@@ -1003,7 +1009,7 @@ class T5Model(T5PreTrainedModel):
|
||||
input_ids=decoder_input_ids,
|
||||
attention_mask=decoder_attention_mask,
|
||||
inputs_embeds=decoder_inputs_embeds,
|
||||
past_key_value_states=decoder_past_key_values,
|
||||
past_key_value_states=past_key_values,
|
||||
encoder_hidden_states=hidden_states,
|
||||
encoder_attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
@@ -1013,15 +1019,12 @@ class T5Model(T5PreTrainedModel):
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
past = (encoder_outputs, decoder_outputs[1]) if use_cache is True else None
|
||||
if not return_dict:
|
||||
if past is not None:
|
||||
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
|
||||
return decoder_outputs + encoder_outputs
|
||||
|
||||
return Seq2SeqModelOutput(
|
||||
last_hidden_state=decoder_outputs.last_hidden_state,
|
||||
decoder_past_key_values=past,
|
||||
past_key_values=decoder_outputs.past_key_values,
|
||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||
decoder_attentions=decoder_outputs.attentions,
|
||||
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
||||
@@ -1080,7 +1083,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
encoder_outputs=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
decoder_past_key_values=None,
|
||||
past_key_values=None,
|
||||
use_cache=None,
|
||||
labels=None,
|
||||
inputs_embeds=None,
|
||||
@@ -1127,10 +1130,16 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
labels = kwargs.pop("lm_labels")
|
||||
if "decoder_past_key_value_states" in kwargs:
|
||||
warnings.warn(
|
||||
"The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.",
|
||||
"The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
decoder_past_key_values = kwargs.pop("decoder_past_key_value_states")
|
||||
past_key_values = kwargs.pop("decoder_past_key_value_states")
|
||||
if "decoder_past_key_values" in kwargs:
|
||||
warnings.warn(
|
||||
"The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
past_key_values = kwargs.pop("decoder_past_key_values")
|
||||
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
@@ -1163,7 +1172,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
|
||||
# If decoding with past key value states, only the last tokens
|
||||
# should be given as an input
|
||||
if decoder_past_key_values is not None:
|
||||
if past_key_values is not None:
|
||||
assert labels is None, "Decoder should not use cached key value states when training."
|
||||
if decoder_input_ids is not None:
|
||||
decoder_input_ids = decoder_input_ids[:, -1:]
|
||||
@@ -1175,7 +1184,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
input_ids=decoder_input_ids,
|
||||
attention_mask=decoder_attention_mask,
|
||||
inputs_embeds=decoder_inputs_embeds,
|
||||
past_key_value_states=decoder_past_key_values,
|
||||
past_key_value_states=past_key_values,
|
||||
encoder_hidden_states=hidden_states,
|
||||
encoder_attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
@@ -1197,17 +1206,14 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
|
||||
# TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
|
||||
|
||||
past = (encoder_outputs, decoder_outputs[1]) if use_cache is True else None
|
||||
if not return_dict:
|
||||
if past is not None:
|
||||
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
|
||||
output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return Seq2SeqLMOutput(
|
||||
loss=loss,
|
||||
logits=lm_logits,
|
||||
decoder_past_key_values=past,
|
||||
past_key_values=decoder_outputs.past_key_values,
|
||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||
decoder_attentions=decoder_outputs.attentions,
|
||||
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
||||
@@ -1215,14 +1221,10 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
encoder_attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, use_cache, **kwargs):
|
||||
assert past is not None, "past has to be defined for encoder_outputs"
|
||||
|
||||
encoder_outputs, decoder_past_key_values = past
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs):
|
||||
return {
|
||||
"decoder_input_ids": input_ids,
|
||||
"decoder_past_key_values": decoder_past_key_values,
|
||||
"past_key_values": past,
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"attention_mask": attention_mask,
|
||||
"use_cache": use_cache,
|
||||
@@ -1231,14 +1233,12 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
def _reorder_cache(self, past, beam_idx):
|
||||
# if decoder past is not included in output
|
||||
# speedy decoding is disabled and no need to reorder
|
||||
if past[1] is None:
|
||||
if past is None:
|
||||
logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
|
||||
return past
|
||||
|
||||
decoder_past = past[1]
|
||||
past = (past[0],)
|
||||
reordered_decoder_past = ()
|
||||
for layer_past_states in decoder_past:
|
||||
for layer_past_states in past:
|
||||
# get the correct batch idx from layer past batch dim
|
||||
# batch dim of `past` is at 2nd position
|
||||
reordered_layer_past_states = ()
|
||||
@@ -1252,4 +1252,4 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
assert len(reordered_layer_past_states) == len(layer_past_states)
|
||||
|
||||
reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
|
||||
return past + (reordered_decoder_past,)
|
||||
return reordered_decoder_past
|
||||
|
||||
@@ -431,7 +431,7 @@ class TFGPT2DoubleHeadsModelOutput(ModelOutput):
|
||||
Base class for outputs of models predicting if two sentences are consecutive or not.
|
||||
|
||||
Args:
|
||||
lm_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
|
||||
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
mc_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`):
|
||||
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
|
||||
@@ -454,7 +454,7 @@ class TFGPT2DoubleHeadsModelOutput(ModelOutput):
|
||||
heads.
|
||||
"""
|
||||
|
||||
lm_logits: tf.Tensor = None
|
||||
logits: tf.Tensor = None
|
||||
mc_logits: tf.Tensor = None
|
||||
past_key_values: Optional[List[tf.Tensor]] = None
|
||||
hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||
@@ -794,7 +794,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
|
||||
return (lm_logits, mc_logits) + transformer_outputs[1:]
|
||||
|
||||
return TFGPT2DoubleHeadsModelOutput(
|
||||
lm_logits=lm_logits,
|
||||
logits=lm_logits,
|
||||
mc_logits=mc_logits,
|
||||
past_key_values=transformer_outputs.past_key_values,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
|
||||
@@ -394,7 +394,7 @@ class TFOpenAIGPTDoubleHeadsModelOutput(ModelOutput):
|
||||
Base class for outputs of models predicting if two sentences are consecutive or not.
|
||||
|
||||
Args:
|
||||
lm_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
|
||||
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
mc_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`):
|
||||
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
|
||||
@@ -411,7 +411,7 @@ class TFOpenAIGPTDoubleHeadsModelOutput(ModelOutput):
|
||||
heads.
|
||||
"""
|
||||
|
||||
lm_logits: tf.Tensor = None
|
||||
logits: tf.Tensor = None
|
||||
mc_logits: tf.Tensor = None
|
||||
hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||
attentions: Optional[Tuple[tf.Tensor]] = None
|
||||
@@ -719,7 +719,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
|
||||
return (lm_logits, mc_logits) + transformer_outputs[1:]
|
||||
|
||||
return TFOpenAIGPTDoubleHeadsModelOutput(
|
||||
lm_logits=lm_logits,
|
||||
logits=lm_logits,
|
||||
mc_logits=mc_logits,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
|
||||
@@ -113,13 +113,13 @@ class TFSeq2SeqModelOutput(ModelOutput):
|
||||
last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the decoder of the model.
|
||||
|
||||
If ``decoder_past_key_values`` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
|
||||
decoder_past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||
If ``past_key_values`` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
|
||||
past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||
List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape
|
||||
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
|
||||
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
|
||||
used (see ``past_key_values`` input) to speed up sequential decoding.
|
||||
decoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
@@ -147,7 +147,7 @@ class TFSeq2SeqModelOutput(ModelOutput):
|
||||
"""
|
||||
|
||||
last_hidden_state: tf.Tensor = None
|
||||
decoder_past_key_values: Optional[List[tf.Tensor]] = None
|
||||
past_key_values: Optional[List[tf.Tensor]] = None
|
||||
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
|
||||
encoder_last_hidden_state: Optional[tf.Tensor] = None
|
||||
@@ -259,12 +259,12 @@ class TFSeq2SeqLMOutput(ModelOutput):
|
||||
Languaged modeling loss.
|
||||
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
decoder_past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||
past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||
List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape
|
||||
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
|
||||
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
|
||||
used (see ``past_key_values`` input) to speed up sequential decoding.
|
||||
decoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
@@ -293,7 +293,7 @@ class TFSeq2SeqLMOutput(ModelOutput):
|
||||
|
||||
loss: Optional[tf.Tensor] = None
|
||||
logits: tf.Tensor = None
|
||||
decoder_past_key_values: Optional[List[tf.Tensor]] = None
|
||||
past_key_values: Optional[List[tf.Tensor]] = None
|
||||
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
|
||||
encoder_last_hidden_state: Optional[tf.Tensor] = None
|
||||
@@ -366,12 +366,12 @@ class TFSeq2SeqSequenceClassifierOutput(ModelOutput):
|
||||
Classification (or regression if config.num_labels==1) loss.
|
||||
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, config.num_labels)`):
|
||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||
decoder_past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||
past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||
List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape
|
||||
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
|
||||
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
|
||||
used (see ``past_key_values`` input) to speed up sequential decoding.
|
||||
decoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
@@ -400,7 +400,7 @@ class TFSeq2SeqSequenceClassifierOutput(ModelOutput):
|
||||
|
||||
loss: Optional[tf.Tensor] = None
|
||||
logits: tf.Tensor = None
|
||||
decoder_past_key_values: Optional[List[tf.Tensor]] = None
|
||||
past_key_values: Optional[List[tf.Tensor]] = None
|
||||
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
|
||||
encoder_last_hidden_state: Optional[tf.Tensor] = None
|
||||
@@ -512,12 +512,12 @@ class TFSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
|
||||
Span-start scores (before SoftMax).
|
||||
end_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`):
|
||||
Span-end scores (before SoftMax).
|
||||
decoder_past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||
past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||
List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape
|
||||
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
|
||||
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
|
||||
used (see ``past_key_values`` input) to speed up sequential decoding.
|
||||
decoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
@@ -547,7 +547,7 @@ class TFSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
|
||||
loss: Optional[tf.Tensor] = None
|
||||
start_logits: tf.Tensor = None
|
||||
end_logits: tf.Tensor = None
|
||||
decoder_past_key_values: Optional[List[tf.Tensor]] = None
|
||||
past_key_values: Optional[List[tf.Tensor]] = None
|
||||
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
|
||||
encoder_last_hidden_state: Optional[tf.Tensor] = None
|
||||
|
||||
@@ -437,15 +437,15 @@ class TFT5Block(tf.keras.layers.Layer):
|
||||
):
|
||||
|
||||
if past_key_value_state is not None:
|
||||
assert self.is_decoder, "Only decoder can use `past_key_value_states`"
|
||||
expected_num_past_key_value_states = 2 if encoder_hidden_states is None else 4
|
||||
assert self.is_decoder, "Only decoder can use `past_key_values`"
|
||||
expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
|
||||
|
||||
error_message = "There should be {} past states. 2 (past / key) for self attention.{} Got {} past key / value states".format(
|
||||
expected_num_past_key_value_states,
|
||||
"2 (past / key) for cross attention" if expected_num_past_key_value_states == 4 else "",
|
||||
expected_num_past_key_values,
|
||||
"2 (past / key) for cross attention" if expected_num_past_key_values == 4 else "",
|
||||
len(past_key_value_state),
|
||||
)
|
||||
assert len(past_key_value_state) == expected_num_past_key_value_states, error_message
|
||||
assert len(past_key_value_state) == expected_num_past_key_values, error_message
|
||||
|
||||
self_attn_past_key_value_state = past_key_value_state[:2]
|
||||
cross_attn_past_key_value_state = past_key_value_state[2:]
|
||||
@@ -586,11 +586,12 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
||||
encoder_attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
head_mask=None,
|
||||
past_key_value_states=None,
|
||||
past_key_values=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
@@ -599,7 +600,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
||||
encoder_attention_mask = inputs[3] if len(inputs) > 3 else encoder_attention_mask
|
||||
inputs_embeds = inputs[4] if len(inputs) > 4 else inputs_embeds
|
||||
head_mask = inputs[5] if len(inputs) > 5 else head_mask
|
||||
past_key_value_states = inputs[6] if len(inputs) > 6 else past_key_value_states
|
||||
past_key_values = inputs[6] if len(inputs) > 6 else past_key_values
|
||||
use_cache = inputs[7] if len(inputs) > 7 else use_cache
|
||||
output_attentions = inputs[8] if len(inputs) > 8 else output_attentions
|
||||
output_hidden_states = inputs[9] if len(inputs) > 9 else output_hidden_states
|
||||
@@ -611,13 +612,26 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
||||
encoder_attention_mask = inputs.get("encoder_attention_mask", encoder_attention_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
past_key_value_states = inputs.get("past_key_value_states", past_key_value_states)
|
||||
past_key_values = inputs.get("past_key_values", past_key_values)
|
||||
use_cache = inputs.get("use_cache", use_cache)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
assert len(inputs) <= 10, "Too many inputs."
|
||||
|
||||
if "past_key_value_states" in inputs:
|
||||
warnings.warn(
|
||||
"The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
past_key_values = inputs.pop("past_key_value_states")
|
||||
else:
|
||||
input_ids = inputs
|
||||
if "past_key_value_states" in kwargs:
|
||||
warnings.warn(
|
||||
"The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
past_key_values = kwargs.pop("past_key_value_states")
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
|
||||
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
|
||||
@@ -639,13 +653,13 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
||||
|
||||
batch_size, seq_length = input_shape
|
||||
|
||||
if past_key_value_states is not None:
|
||||
if past_key_values is not None:
|
||||
assert seq_length == 1, "Input shape is {}, but should be {} when using past_key_value_sates".format(
|
||||
input_shape, (batch_size, 1)
|
||||
)
|
||||
# required mask seq length can be calculated via length of past
|
||||
# key value states and seq_length = 1 for the last token
|
||||
mask_seq_length = shape_list(past_key_value_states[0][0])[2] + seq_length
|
||||
mask_seq_length = shape_list(past_key_values[0][0])[2] + seq_length
|
||||
else:
|
||||
mask_seq_length = seq_length
|
||||
|
||||
@@ -655,9 +669,9 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
||||
encoder_seq_length = shape_list(encoder_hidden_states)[1]
|
||||
encoder_attention_mask = tf.fill((batch_size, encoder_seq_length), 1)
|
||||
|
||||
# initialize past_key_value_states with `None` if past does not exist
|
||||
if past_key_value_states is None:
|
||||
past_key_value_states = [None] * len(self.block)
|
||||
# initialize past_key_values with `None` if past does not exist
|
||||
if past_key_values is None:
|
||||
past_key_values = [None] * len(self.block)
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
@@ -677,7 +691,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
||||
)
|
||||
causal_mask = tf.cast(causal_mask, dtype=tf.float32)
|
||||
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
||||
if past_key_value_states[0] is not None:
|
||||
if past_key_values[0] is not None:
|
||||
extended_attention_mask = extended_attention_mask[:, :, -1:, :]
|
||||
else:
|
||||
extended_attention_mask = attention_mask[:, None, None, :]
|
||||
@@ -726,7 +740,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
||||
|
||||
hidden_states = self.dropout(inputs_embeds, training=training)
|
||||
|
||||
for i, (layer_module, past_key_value_state) in enumerate(zip(self.block, past_key_value_states)):
|
||||
for i, (layer_module, past_key_value_state) in enumerate(zip(self.block, past_key_values)):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
@@ -878,7 +892,7 @@ T5_INPUTS_DOCSTRING = r"""
|
||||
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||
decoder_input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation.
|
||||
If `decoder_past_key_value_states` is used, optionally only the last `decoder_input_ids` have to be input (see `decoder_past_key_value_states`).
|
||||
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
|
||||
attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
@@ -889,13 +903,13 @@ T5_INPUTS_DOCSTRING = r"""
|
||||
Used in the cross-attention of the decoder.
|
||||
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
|
||||
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
|
||||
decoder_past_key_value_states (:obj:`tuple(tuple(tf.Tensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
past_key_values (:obj:`tuple(tuple(tf.Tensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains pre-computed key and value hidden-states of the attention blocks.
|
||||
Can be used to speed up decoding.
|
||||
If `decoder_past_key_value_states` are used, the user can optionally input only the last `decoder_input_ids`
|
||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
|
||||
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
If `use_cache` is True, `decoder_past_key_value_states` are returned and can be used to speed up decoding (see `decoder_past_key_value_states`).
|
||||
If `use_cache` is True, `past_key_values` are returned and can be used to speed up decoding (see `past_key_values`).
|
||||
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
||||
Optionally, instead of passing :obj:`inputs` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `inputs` indices into associated vectors
|
||||
@@ -969,7 +983,7 @@ class TFT5Model(TFT5PreTrainedModel):
|
||||
encoder_outputs=None,
|
||||
inputs_embeds=None,
|
||||
head_mask=None,
|
||||
decoder_past_key_value_states=None,
|
||||
past_key_values=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
decoder_inputs_embeds=None,
|
||||
@@ -978,6 +992,7 @@ class TFT5Model(TFT5PreTrainedModel):
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
@@ -999,7 +1014,7 @@ class TFT5Model(TFT5PreTrainedModel):
|
||||
encoder_outputs = inputs[2] if len(inputs) > 2 else encoder_outputs
|
||||
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
|
||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||
decoder_past_key_value_states = inputs[5] if len(inputs) > 5 else decoder_past_key_value_states
|
||||
past_key_values = inputs[5] if len(inputs) > 5 else past_key_values
|
||||
decoder_input_ids = inputs[6] if len(inputs) > 6 else decoder_input_ids
|
||||
decoder_attention_mask = inputs[7] if len(inputs) > 7 else decoder_attention_mask
|
||||
decoder_inputs_embeds = inputs[8] if len(inputs) > 8 else decoder_inputs_embeds
|
||||
@@ -1017,7 +1032,7 @@ class TFT5Model(TFT5PreTrainedModel):
|
||||
encoder_outputs = inputs.get("encoder_outputs", encoder_outputs)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
decoder_past_key_value_states = inputs.get("past_key_value_states", decoder_past_key_value_states)
|
||||
past_key_values = inputs.get("past_key_values", past_key_values)
|
||||
decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids)
|
||||
decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask)
|
||||
decoder_inputs_embeds = inputs.get("decoder_inputs_embeds", decoder_inputs_embeds)
|
||||
@@ -1026,9 +1041,23 @@ class TFT5Model(TFT5PreTrainedModel):
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
assert len(inputs) <= 13, "Too many inputs."
|
||||
|
||||
if "past_key_value_states" in inputs:
|
||||
warnings.warn(
|
||||
"The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
past_key_values = inputs.pop("past_key_value_states")
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
if "past_key_value_states" in kwargs:
|
||||
warnings.warn(
|
||||
"The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
past_key_values = kwargs.pop("past_key_value_states")
|
||||
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||
|
||||
@@ -1054,7 +1083,7 @@ class TFT5Model(TFT5PreTrainedModel):
|
||||
|
||||
# If decoding with past key value states, only the last tokens
|
||||
# should be given as an input
|
||||
if decoder_past_key_value_states is not None:
|
||||
if past_key_values is not None:
|
||||
if decoder_input_ids is not None:
|
||||
decoder_input_ids = decoder_input_ids[:, -1:]
|
||||
if decoder_inputs_embeds is not None:
|
||||
@@ -1069,7 +1098,7 @@ class TFT5Model(TFT5PreTrainedModel):
|
||||
attention_mask,
|
||||
decoder_inputs_embeds,
|
||||
head_mask,
|
||||
decoder_past_key_value_states,
|
||||
past_key_values,
|
||||
use_cache,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
@@ -1103,7 +1132,7 @@ class TFT5Model(TFT5PreTrainedModel):
|
||||
|
||||
return TFSeq2SeqModelOutput(
|
||||
last_hidden_state=decoder_outputs[0],
|
||||
decoder_past_key_values=past,
|
||||
past_key_values=past,
|
||||
decoder_hidden_states=decoder_outputs[2],
|
||||
decoder_attentions=decoder_outputs[3],
|
||||
encoder_last_hidden_state=encoder_outputs[0],
|
||||
@@ -1164,7 +1193,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
||||
encoder_outputs=None,
|
||||
inputs_embeds=None,
|
||||
head_mask=None,
|
||||
decoder_past_key_value_states=None,
|
||||
past_key_values=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
decoder_inputs_embeds=None,
|
||||
@@ -1174,6 +1203,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
@@ -1204,7 +1234,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
||||
encoder_outputs = inputs[2] if len(inputs) > 2 else encoder_outputs
|
||||
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
|
||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||
decoder_past_key_value_states = inputs[5] if len(inputs) > 5 else decoder_past_key_value_states
|
||||
past_key_values = inputs[5] if len(inputs) > 5 else past_key_values
|
||||
decoder_input_ids = inputs[6] if len(inputs) > 6 else decoder_input_ids
|
||||
decoder_attention_mask = inputs[7] if len(inputs) > 7 else decoder_attention_mask
|
||||
decoder_inputs_embeds = inputs[8] if len(inputs) > 8 else decoder_inputs_embeds
|
||||
@@ -1223,7 +1253,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
||||
encoder_outputs = inputs.get("encoder_outputs", encoder_outputs)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
decoder_past_key_value_states = inputs.get("past_key_value_states", decoder_past_key_value_states)
|
||||
past_key_values = inputs.get("past_key_values", past_key_values)
|
||||
decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids)
|
||||
decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask)
|
||||
decoder_inputs_embeds = inputs.get("decoder_inputs_embeds", decoder_inputs_embeds)
|
||||
@@ -1233,9 +1263,23 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
labels = inputs.get("labels", labels)
|
||||
assert len(inputs) <= 14, "Too many inputs."
|
||||
|
||||
if "past_key_value_states" in inputs:
|
||||
warnings.warn(
|
||||
"The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
past_key_values = inputs.pop("past_key_value_states")
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
if "past_key_value_states" in kwargs:
|
||||
warnings.warn(
|
||||
"The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
past_key_values = kwargs.pop("past_key_value_states")
|
||||
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||
|
||||
@@ -1266,7 +1310,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
||||
|
||||
# If decoding with past key value states, only the last tokens
|
||||
# should be given as an input
|
||||
if decoder_past_key_value_states is not None:
|
||||
if past_key_values is not None:
|
||||
if decoder_input_ids is not None:
|
||||
decoder_input_ids = decoder_input_ids[:, -1:]
|
||||
if decoder_inputs_embeds is not None:
|
||||
@@ -1281,7 +1325,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
||||
attention_mask,
|
||||
decoder_inputs_embeds,
|
||||
head_mask,
|
||||
decoder_past_key_value_states,
|
||||
past_key_values,
|
||||
use_cache,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
@@ -1324,7 +1368,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
||||
return TFSeq2SeqLMOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
decoder_past_key_values=past,
|
||||
past_key_values=past,
|
||||
decoder_hidden_states=decoder_outputs[2],
|
||||
decoder_attentions=decoder_outputs[3],
|
||||
encoder_last_hidden_state=encoder_outputs[0],
|
||||
@@ -1337,14 +1381,14 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
||||
|
||||
# first step
|
||||
if len(past) < 2:
|
||||
encoder_outputs, decoder_past_key_value_states = past, None
|
||||
encoder_outputs, past_key_values = past, None
|
||||
else:
|
||||
encoder_outputs, decoder_past_key_value_states = past[0], past[1]
|
||||
encoder_outputs, past_key_values = past[0], past[1]
|
||||
|
||||
return {
|
||||
"inputs": None, # inputs don't have to be defined, but still need to be passed to make Keras.layer.__call__ happy
|
||||
"decoder_input_ids": inputs, # inputs are the decoder_input_ids
|
||||
"decoder_past_key_value_states": decoder_past_key_value_states,
|
||||
"past_key_values": past_key_values,
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"attention_mask": attention_mask,
|
||||
"use_cache": use_cache,
|
||||
|
||||
@@ -661,6 +661,15 @@ class TransfoXLLMHeadModelOutput(ModelOutput):
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
@property
|
||||
def logits(self):
|
||||
# prediciton scores are the output of the adaptive softmax, see
|
||||
# the file `modeling_transfo_xl_utilities`. Since the adaptive
|
||||
# softmax returns the log softmax value, `self.prediciton_scores`
|
||||
# are strictly speaking not exactly `logits`, but behave the same
|
||||
# way logits do.
|
||||
return self.prediction_scores
|
||||
|
||||
|
||||
TRANSFO_XL_START_DOCSTRING = r"""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user