[Generate] Facilitate PyTorch generate using ModelOutputs (#6735)
* fix generate for GPT2 Double Head * fix gpt2 double head model * fix bart / t5 * also add for no beam search * fix no beam search * fix encoder decoder * simplify t5 * simplify t5 * fix t5 tests * fix BART * fix transfo-xl * fix conflict * integrating sylvains and sams comments * fix tf past_decoder_key_values * fix enc dec test
This commit is contained in:
committed by
GitHub
parent
397f819615
commit
afc4ece462
@@ -1,12 +1,13 @@
|
|||||||
Encoder Decoder Models
|
Encoder Decoder Models
|
||||||
------------------------
|
------------------------
|
||||||
|
|
||||||
This class can wrap an encoder model, such as ``BertModel``, and a decoder model with a language modeling head, such as ``BertForMaskedLM``, into an encoder-decoder model.
|
The :class:`~transformers.EncoderDecoderModel` can be used to initialize a sequence-to-sequence model with any pre-trained autoencoding model as the encoder and any pre-trained autoregressive model as the decoder.
|
||||||
|
|
||||||
The ``EncoderDecoderModel`` class allows instantiating an encoder-decoder model using the ``from_encoder_decoder_pretrain`` class method, which takes a pretrained encoder and a pretrained decoder model as input.
|
The effectiveness of initializing sequence-to-sequence models with pre-trained checkpoints for sequence generation tasks was shown in `Leveraging Pre-trained Checkpoints for Sequence Generation Tasks <https://arxiv.org/abs/1907.12461>`__ by Sascha Rothe, Shashi Narayan, Aliaksei Severyn.
|
||||||
The ``EncoderDecoderModel`` is saved using the standard ``save_pretrained()`` method and can also again be loaded using the standard ``from_pretrained()`` method.
|
|
||||||
|
|
||||||
An application of this architecture could be *summarization* using two pretrained Bert models, as is shown in the paper `Text Summarization with Pretrained Encoders <https://arxiv.org/abs/1908.08345>`_ by Yang Liu and Mirella Lapata.
|
After such an :class:`~transformers.EncoderDecoderModel` has been trained / fine-tuned, it can be saved / loaded just like any other models (see Examples for more information).
|
||||||
|
|
||||||
|
An application of this architecture could be to leverage two pre-trained :obj:`transformers.BertModel` models as the encoder and decoder for a summarization model, as was shown in `Text Summarization with Pretrained Encoders <https://arxiv.org/abs/1908.08345>`_ by Yang Liu and Mirella Lapata.
|
||||||
|
|
||||||
|
|
||||||
``EncoderDecoderConfig``
|
``EncoderDecoderConfig``
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import torch
|
|||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from .file_utils import ModelOutput
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -46,14 +47,6 @@ class GenerationMixin:
|
|||||||
"""
|
"""
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def _use_cache(self, outputs, use_cache):
|
|
||||||
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
|
|
||||||
if len(outputs) <= 1 or use_cache is False:
|
|
||||||
return False
|
|
||||||
if hasattr(self.config, "mem_len") and self.config.mem_len == 0:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
|
def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
|
||||||
"""
|
"""
|
||||||
Enforce the repetition penalty (from the `CTRL paper <https://arxiv.org/abs/1909.05858>`__).
|
Enforce the repetition penalty (from the `CTRL paper <https://arxiv.org/abs/1909.05858>`__).
|
||||||
@@ -137,7 +130,7 @@ class GenerationMixin:
|
|||||||
attention_mask: Optional[torch.LongTensor] = None,
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
decoder_start_token_id: Optional[int] = None,
|
decoder_start_token_id: Optional[int] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
**model_specific_kwargs
|
**model_kwargs
|
||||||
) -> torch.LongTensor:
|
) -> torch.LongTensor:
|
||||||
r"""
|
r"""
|
||||||
Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
|
Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
|
||||||
@@ -208,7 +201,7 @@ class GenerationMixin:
|
|||||||
use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
|
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
|
||||||
speed up decoding.
|
speed up decoding.
|
||||||
model_specific_kwargs:
|
model_kwargs:
|
||||||
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
|
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
@@ -400,7 +393,7 @@ class GenerationMixin:
|
|||||||
|
|
||||||
# get encoder and store encoder outputs
|
# get encoder and store encoder outputs
|
||||||
encoder = self.get_encoder()
|
encoder = self.get_encoder()
|
||||||
encoder_outputs: tuple = encoder(input_ids, attention_mask=attention_mask)
|
encoder_outputs: ModelOutput = encoder(input_ids, attention_mask=attention_mask, return_dict=True)
|
||||||
|
|
||||||
# Expand input ids if num_beams > 1 or num_return_sequences > 1
|
# Expand input ids if num_beams > 1 or num_return_sequences > 1
|
||||||
if num_return_sequences > 1 or num_beams > 1:
|
if num_return_sequences > 1 or num_beams > 1:
|
||||||
@@ -428,8 +421,8 @@ class GenerationMixin:
|
|||||||
cur_len = 1
|
cur_len = 1
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
batch_size == encoder_outputs[0].shape[0]
|
batch_size == encoder_outputs.last_hidden_state.shape[0]
|
||||||
), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} "
|
), f"expected encoder_outputs.last_hidden_state to have 1st dimension bs={batch_size}, got {encoder_outputs.last_hidden_state.shape[0]} "
|
||||||
|
|
||||||
# expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
|
# expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
|
||||||
expanded_batch_idxs = (
|
expanded_batch_idxs = (
|
||||||
@@ -439,11 +432,16 @@ class GenerationMixin:
|
|||||||
.view(-1)
|
.view(-1)
|
||||||
.to(input_ids.device)
|
.to(input_ids.device)
|
||||||
)
|
)
|
||||||
|
|
||||||
# expand encoder_outputs
|
# expand encoder_outputs
|
||||||
encoder_outputs = (encoder_outputs[0].index_select(0, expanded_batch_idxs), *encoder_outputs[1:])
|
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
|
||||||
|
0, expanded_batch_idxs
|
||||||
|
)
|
||||||
|
|
||||||
|
# save encoder_outputs in `model_kwargs`
|
||||||
|
model_kwargs["encoder_outputs"] = encoder_outputs
|
||||||
|
|
||||||
else:
|
else:
|
||||||
encoder_outputs = None
|
|
||||||
cur_len = input_ids.shape[-1]
|
cur_len = input_ids.shape[-1]
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
@@ -471,10 +469,9 @@ class GenerationMixin:
|
|||||||
length_penalty=length_penalty,
|
length_penalty=length_penalty,
|
||||||
num_beams=num_beams,
|
num_beams=num_beams,
|
||||||
vocab_size=vocab_size,
|
vocab_size=vocab_size,
|
||||||
encoder_outputs=encoder_outputs,
|
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
model_specific_kwargs=model_specific_kwargs,
|
model_kwargs=model_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
output = self._generate_no_beam_search(
|
output = self._generate_no_beam_search(
|
||||||
@@ -492,10 +489,9 @@ class GenerationMixin:
|
|||||||
pad_token_id=pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
batch_size=effective_batch_size,
|
batch_size=effective_batch_size,
|
||||||
encoder_outputs=encoder_outputs,
|
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
model_specific_kwargs=model_specific_kwargs,
|
model_kwargs=model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
@@ -516,10 +512,9 @@ class GenerationMixin:
|
|||||||
pad_token_id,
|
pad_token_id,
|
||||||
eos_token_id,
|
eos_token_id,
|
||||||
batch_size,
|
batch_size,
|
||||||
encoder_outputs,
|
|
||||||
attention_mask,
|
attention_mask,
|
||||||
use_cache,
|
use_cache,
|
||||||
model_specific_kwargs,
|
model_kwargs,
|
||||||
):
|
):
|
||||||
"""Generate sequences for each example without beam search (num_beams == 1).
|
"""Generate sequences for each example without beam search (num_beams == 1).
|
||||||
All returned sequences are generated independently.
|
All returned sequences are generated independently.
|
||||||
@@ -528,15 +523,14 @@ class GenerationMixin:
|
|||||||
unfinished_sents = input_ids.new(batch_size).fill_(1)
|
unfinished_sents = input_ids.new(batch_size).fill_(1)
|
||||||
sent_lengths = input_ids.new(batch_size).fill_(max_length)
|
sent_lengths = input_ids.new(batch_size).fill_(max_length)
|
||||||
|
|
||||||
past = (encoder_outputs, None) if encoder_outputs is not None else None
|
past = None
|
||||||
|
|
||||||
while cur_len < max_length:
|
while cur_len < max_length:
|
||||||
model_inputs = self.prepare_inputs_for_generation(
|
model_inputs = self.prepare_inputs_for_generation(
|
||||||
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
|
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs = self(**model_inputs)
|
outputs = self(**model_inputs, return_dict=True)
|
||||||
next_token_logits = outputs[0][:, -1, :]
|
next_token_logits = outputs.logits[:, -1, :]
|
||||||
|
|
||||||
scores = self.postprocess_next_token_scores(
|
scores = self.postprocess_next_token_scores(
|
||||||
scores=next_token_logits,
|
scores=next_token_logits,
|
||||||
@@ -553,8 +547,10 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# if model has past, then set the past variable to speed up decoding
|
# if model has past, then set the past variable to speed up decoding
|
||||||
if self._use_cache(outputs, use_cache):
|
if "past_key_values" in outputs:
|
||||||
past = outputs[1]
|
past = outputs.past_key_values
|
||||||
|
elif "mems" in outputs:
|
||||||
|
past = outputs.mems
|
||||||
|
|
||||||
if do_sample:
|
if do_sample:
|
||||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||||
@@ -621,10 +617,9 @@ class GenerationMixin:
|
|||||||
length_penalty,
|
length_penalty,
|
||||||
num_beams,
|
num_beams,
|
||||||
vocab_size,
|
vocab_size,
|
||||||
encoder_outputs,
|
|
||||||
attention_mask,
|
attention_mask,
|
||||||
use_cache,
|
use_cache,
|
||||||
model_specific_kwargs,
|
model_kwargs,
|
||||||
):
|
):
|
||||||
"""Generate sequences for each example with beam search."""
|
"""Generate sequences for each example with beam search."""
|
||||||
|
|
||||||
@@ -643,21 +638,24 @@ class GenerationMixin:
|
|||||||
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
|
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
|
||||||
|
|
||||||
# cache compute states
|
# cache compute states
|
||||||
past = (encoder_outputs, None) if encoder_outputs is not None else None
|
past = None
|
||||||
|
|
||||||
# done sentences
|
# done sentences
|
||||||
done = [False for _ in range(batch_size)]
|
done = [False for _ in range(batch_size)]
|
||||||
|
|
||||||
while cur_len < max_length:
|
while cur_len < max_length:
|
||||||
model_inputs = self.prepare_inputs_for_generation(
|
model_inputs = self.prepare_inputs_for_generation(
|
||||||
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
|
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
|
||||||
)
|
)
|
||||||
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
|
outputs = self(**model_inputs, return_dict=True) # (batch_size * num_beams, cur_len, vocab_size)
|
||||||
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
|
next_token_logits = outputs.logits[:, -1, :] # (batch_size * num_beams, vocab_size)
|
||||||
|
|
||||||
# if model has past, then set the past variable to speed up decoding
|
# if model has past, then set the past variable to speed up decoding
|
||||||
if self._use_cache(outputs, use_cache):
|
if "past_key_values" in outputs:
|
||||||
past = outputs[1]
|
past = outputs.past_key_values
|
||||||
|
elif "mems" in outputs:
|
||||||
|
past = outputs.mems
|
||||||
|
|
||||||
if self.config.is_encoder_decoder and do_sample is False:
|
if self.config.is_encoder_decoder and do_sample is False:
|
||||||
# TODO (PVP) still a bit hacky here - there might be a better solution
|
# TODO (PVP) still a bit hacky here - there might be a better solution
|
||||||
next_token_logits = self.adjust_logits_during_generation(
|
next_token_logits = self.adjust_logits_during_generation(
|
||||||
|
|||||||
@@ -111,15 +111,15 @@ BART_INPUTS_DOCSTRING = r"""
|
|||||||
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
|
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
|
||||||
If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify.
|
If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify.
|
||||||
See diagram 1 in the paper for more info on the default strategy
|
See diagram 1 in the paper for more info on the default strategy
|
||||||
decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
Contains pre-computed key and value hidden-states of the attention blocks.
|
Contains pre-computed key and value hidden-states of the attention blocks.
|
||||||
Can be used to speed up decoding.
|
Can be used to speed up decoding.
|
||||||
If ``decoder_past_key_value_states`` are used, the user can optionally input only the last
|
If ``past_key_values`` are used, the user can optionally input only the last
|
||||||
``decoder_input_ids`` (those that don't have their past key value states given to this model) of shape
|
``decoder_input_ids`` (those that don't have their past key value states given to this model) of shape
|
||||||
:obj:`(batch_size, 1)` instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`.
|
:obj:`(batch_size, 1)` instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
If `use_cache` is True, ``decoder_past_key_values`` are returned and can be used to speed up decoding (see
|
If `use_cache` is True, ``past_key_values`` are returned and can be used to speed up decoding (see
|
||||||
``decoder_past_key_values``).
|
``past_key_values``).
|
||||||
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
||||||
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
|
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
|
||||||
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
||||||
@@ -502,7 +502,7 @@ class BartDecoder(nn.Module):
|
|||||||
encoder_padding_mask,
|
encoder_padding_mask,
|
||||||
decoder_padding_mask,
|
decoder_padding_mask,
|
||||||
decoder_causal_mask,
|
decoder_causal_mask,
|
||||||
decoder_past_key_values=None,
|
past_key_values=None,
|
||||||
use_cache=False,
|
use_cache=False,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
@@ -519,7 +519,7 @@ class BartDecoder(nn.Module):
|
|||||||
encoder_hidden_states: output from the encoder, used for
|
encoder_hidden_states: output from the encoder, used for
|
||||||
encoder-side attention
|
encoder-side attention
|
||||||
encoder_padding_mask: for ignoring pad tokens
|
encoder_padding_mask: for ignoring pad tokens
|
||||||
decoder_past_key_values (dict or None): dictionary used for storing state during generation
|
past_key_values (dict or None): dictionary used for storing state during generation
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
BaseModelOutputWithPast or tuple:
|
BaseModelOutputWithPast or tuple:
|
||||||
@@ -530,10 +530,16 @@ class BartDecoder(nn.Module):
|
|||||||
"""
|
"""
|
||||||
if "decoder_cached_states" in unused:
|
if "decoder_cached_states" in unused:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.",
|
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||||
FutureWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
decoder_past_key_values = unused.pop("decoder_cached_states")
|
past_key_values = unused.pop("decoder_cached_states")
|
||||||
|
if "decoder_past_key_values" in unused:
|
||||||
|
warnings.warn(
|
||||||
|
"The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
past_key_values = unused.pop("decoder_past_key_values")
|
||||||
|
|
||||||
# check attention mask and invert
|
# check attention mask and invert
|
||||||
if encoder_padding_mask is not None:
|
if encoder_padding_mask is not None:
|
||||||
@@ -568,7 +574,7 @@ class BartDecoder(nn.Module):
|
|||||||
if self.training and (dropout_probability < self.layerdrop):
|
if self.training and (dropout_probability < self.layerdrop):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
layer_state = decoder_past_key_values[idx] if decoder_past_key_values is not None else None
|
layer_state = past_key_values[idx] if past_key_values is not None else None
|
||||||
|
|
||||||
x, layer_self_attn, layer_past = decoder_layer(
|
x, layer_self_attn, layer_past = decoder_layer(
|
||||||
x,
|
x,
|
||||||
@@ -594,10 +600,7 @@ class BartDecoder(nn.Module):
|
|||||||
x = x.transpose(0, 1)
|
x = x.transpose(0, 1)
|
||||||
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
|
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
|
||||||
|
|
||||||
if use_cache:
|
next_cache = next_decoder_cache if use_cache else None
|
||||||
next_cache = ((encoder_hidden_states, encoder_padding_mask), next_decoder_cache)
|
|
||||||
else:
|
|
||||||
next_cache = None
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(v for v in [x, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
return tuple(v for v in [x, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||||
@@ -869,13 +872,19 @@ class BartModel(PretrainedBartModel):
|
|||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
encoder_outputs: Optional[Tuple] = None,
|
encoder_outputs: Optional[Tuple] = None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
decoder_past_key_values=None,
|
past_key_values=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
if "decoder_past_key_values" in kwargs:
|
||||||
|
warnings.warn(
|
||||||
|
"The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
past_key_values = kwargs.pop("decoder_past_key_values")
|
||||||
|
|
||||||
if decoder_input_ids is None:
|
if decoder_input_ids is None:
|
||||||
use_cache = False
|
use_cache = False
|
||||||
@@ -924,7 +933,7 @@ class BartModel(PretrainedBartModel):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
decoder_padding_mask,
|
decoder_padding_mask,
|
||||||
decoder_causal_mask=causal_mask,
|
decoder_causal_mask=causal_mask,
|
||||||
decoder_past_key_values=decoder_past_key_values,
|
past_key_values=past_key_values,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -936,7 +945,7 @@ class BartModel(PretrainedBartModel):
|
|||||||
|
|
||||||
return Seq2SeqModelOutput(
|
return Seq2SeqModelOutput(
|
||||||
last_hidden_state=decoder_outputs.last_hidden_state,
|
last_hidden_state=decoder_outputs.last_hidden_state,
|
||||||
decoder_past_key_values=decoder_outputs.past_key_values,
|
past_key_values=decoder_outputs.past_key_values,
|
||||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||||
decoder_attentions=decoder_outputs.attentions,
|
decoder_attentions=decoder_outputs.attentions,
|
||||||
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
||||||
@@ -994,7 +1003,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
|||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
decoder_past_key_values=None,
|
past_key_values=None,
|
||||||
labels=None,
|
labels=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
@@ -1037,10 +1046,16 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
|||||||
labels = unused.pop("lm_labels")
|
labels = unused.pop("lm_labels")
|
||||||
if "decoder_cached_states" in unused:
|
if "decoder_cached_states" in unused:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.",
|
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||||
FutureWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
decoder_past_key_values = unused.pop("decoder_cached_states")
|
past_key_values = unused.pop("decoder_cached_states")
|
||||||
|
if "decoder_past_key_values" in unused:
|
||||||
|
warnings.warn(
|
||||||
|
"The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
past_key_values = unused.pop("decoder_past_key_values")
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
@@ -1054,7 +1069,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
|||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
decoder_past_key_values=decoder_past_key_values,
|
past_key_values=past_key_values,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -1075,7 +1090,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
|||||||
return Seq2SeqLMOutput(
|
return Seq2SeqLMOutput(
|
||||||
loss=masked_lm_loss,
|
loss=masked_lm_loss,
|
||||||
logits=lm_logits,
|
logits=lm_logits,
|
||||||
decoder_past_key_values=outputs.decoder_past_key_values,
|
past_key_values=outputs.past_key_values,
|
||||||
decoder_hidden_states=outputs.decoder_hidden_states,
|
decoder_hidden_states=outputs.decoder_hidden_states,
|
||||||
decoder_attentions=outputs.decoder_attentions,
|
decoder_attentions=outputs.decoder_attentions,
|
||||||
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
||||||
@@ -1083,14 +1098,13 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
|||||||
encoder_attentions=outputs.encoder_attentions,
|
encoder_attentions=outputs.encoder_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs):
|
def prepare_inputs_for_generation(
|
||||||
assert past is not None, "past has to be defined for encoder_outputs"
|
self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs
|
||||||
|
):
|
||||||
encoder_outputs, decoder_past_key_values = past
|
|
||||||
return {
|
return {
|
||||||
"input_ids": None, # encoder_outputs is defined. input_ids not needed
|
"input_ids": None, # encoder_outputs is defined. input_ids not needed
|
||||||
"encoder_outputs": encoder_outputs,
|
"encoder_outputs": encoder_outputs,
|
||||||
"decoder_past_key_values": decoder_past_key_values,
|
"past_key_values": past,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
@@ -1109,20 +1123,14 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _reorder_cache(past, beam_idx):
|
def _reorder_cache(past, beam_idx):
|
||||||
((enc_out, enc_mask), decoder_past_key_values) = past
|
|
||||||
reordered_past = []
|
reordered_past = []
|
||||||
for layer_past in decoder_past_key_values:
|
for layer_past in past:
|
||||||
# get the correct batch idx from decoder layer's batch dim for cross and self-attn
|
# get the correct batch idx from decoder layer's batch dim for cross and self-attn
|
||||||
layer_past_new = {
|
layer_past_new = {
|
||||||
attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
|
attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
|
||||||
}
|
}
|
||||||
reordered_past.append(layer_past_new)
|
reordered_past.append(layer_past_new)
|
||||||
|
return reordered_past
|
||||||
new_enc_out = enc_out if enc_out is None else enc_out.index_select(0, beam_idx)
|
|
||||||
new_enc_mask = enc_mask if enc_mask is None else enc_mask.index_select(0, beam_idx)
|
|
||||||
|
|
||||||
past = ((new_enc_out, new_enc_mask), reordered_past)
|
|
||||||
return past
|
|
||||||
|
|
||||||
def get_encoder(self):
|
def get_encoder(self):
|
||||||
return self.model.encoder
|
return self.model.encoder
|
||||||
@@ -1208,7 +1216,7 @@ class BartForSequenceClassification(PretrainedBartModel):
|
|||||||
return Seq2SeqSequenceClassifierOutput(
|
return Seq2SeqSequenceClassifierOutput(
|
||||||
loss=loss,
|
loss=loss,
|
||||||
logits=logits,
|
logits=logits,
|
||||||
decoder_past_key_values=outputs.decoder_past_key_values,
|
past_key_values=outputs.past_key_values,
|
||||||
decoder_hidden_states=outputs.decoder_hidden_states,
|
decoder_hidden_states=outputs.decoder_hidden_states,
|
||||||
decoder_attentions=outputs.decoder_attentions,
|
decoder_attentions=outputs.decoder_attentions,
|
||||||
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
||||||
@@ -1316,7 +1324,7 @@ class BartForQuestionAnswering(PretrainedBartModel):
|
|||||||
loss=total_loss,
|
loss=total_loss,
|
||||||
start_logits=start_logits,
|
start_logits=start_logits,
|
||||||
end_logits=end_logits,
|
end_logits=end_logits,
|
||||||
decoder_past_key_values=outputs.decoder_past_key_values,
|
past_key_values=outputs.past_key_values,
|
||||||
decoder_hidden_states=outputs.decoder_hidden_states,
|
decoder_hidden_states=outputs.decoder_hidden_states,
|
||||||
decoder_attentions=outputs.decoder_attentions,
|
decoder_attentions=outputs.decoder_attentions,
|
||||||
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
||||||
|
|||||||
@@ -19,13 +19,79 @@ from typing import Optional
|
|||||||
|
|
||||||
from .configuration_encoder_decoder import EncoderDecoderConfig
|
from .configuration_encoder_decoder import EncoderDecoderConfig
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
|
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable, replace_return_docstrings
|
||||||
|
from .modeling_outputs import Seq2SeqLMOutput
|
||||||
from .modeling_utils import PreTrainedModel
|
from .modeling_utils import PreTrainedModel
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
_CONFIG_FOR_DOC = "EncoderDecoderConfig"
|
||||||
|
|
||||||
|
ENCODER_DECODER_START_DOCSTRING = r"""
|
||||||
|
This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via the :meth:`~transformers.AutoModel.from_pretrained` function and the decoder is loaded via the :meth:`~transformers.AutoModelForCausalLM.from_pretrained` function.
|
||||||
|
Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream generative task, *e.g.* summarization.
|
||||||
|
|
||||||
|
The effectiveness of initializing sequence-to-sequence models with pre-trained checkpoints for sequence generation tasks was shown in `Leveraging Pre-trained Checkpoints for Sequence Generation Tasks <https://arxiv.org/abs/1907.12461>`__ by Sascha Rothe, Shashi Narayan, Aliaksei Severyn.
|
||||||
|
Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu.
|
||||||
|
|
||||||
|
After such an Encoder Decoder model has been trained / fine-tuned, it can be saved / loaded just like any other models (see Examples for more information).
|
||||||
|
|
||||||
|
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#module>`__ sub-class. Use it as a
|
||||||
|
regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
config (:class:`~transformers.EncoderDecoderConfig`): Model configuration class with all the parameters of the model.
|
||||||
|
Initializing with a config file does not load the weights associated with the model, only the configuration.
|
||||||
|
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ENCODER_DECODER_INPUTS_DOCSTRING = r"""
|
||||||
|
Args:
|
||||||
|
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||||
|
Indices of input sequence tokens in the vocabulary for the encoder.
|
||||||
|
Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`.
|
||||||
|
See :meth:`~transformers.PreTrainedTokenizer.encode` and
|
||||||
|
:meth:`~transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||||
|
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
||||||
|
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||||
|
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated vectors
|
||||||
|
than the model's internal embedding lookup matrix.
|
||||||
|
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||||
|
Mask to avoid performing attention on padding token indices for the encoder.
|
||||||
|
Mask values selected in ``[0, 1]``:
|
||||||
|
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||||
|
encoder_outputs (:obj:`tuple(torch.FloatTensor)`, `optional`, defaults to :obj:`None`):
|
||||||
|
This tuple must consist of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`)
|
||||||
|
`last_hidden_state` (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`) is a tensor of hidden-states at the output of the last layer of the encoder.
|
||||||
|
Used in the cross-attention of the decoder.
|
||||||
|
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||||
|
Provide for sequence to sequence training to the decoder.
|
||||||
|
Indices can be obtained using :class:`transformers.PreTrainedTokenizer`.
|
||||||
|
See :func:`transformers.PreTrainedTokenizer.encode` and
|
||||||
|
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||||
|
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
|
||||||
|
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
|
||||||
|
decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
||||||
|
Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation.
|
||||||
|
This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
|
||||||
|
than the model's internal embedding lookup matrix.
|
||||||
|
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||||
|
Labels for computing the masked language modeling loss for the decoder.
|
||||||
|
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
||||||
|
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
|
||||||
|
in ``[0, ..., config.vocab_size]``
|
||||||
|
return_dict (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
||||||
|
If set to ``True``, the model will return a :class:`~transformers.modeling_outputs.Seq2SeqLMOutput` instead of a
|
||||||
|
plain tuple.
|
||||||
|
kwargs: (`optional`) Remaining dictionary of keyword arguments. Keyword arguments come in two flavors:
|
||||||
|
- Without a prefix which will be input as ``**encoder_kwargs`` for the encoder forward function.
|
||||||
|
- With a `decoder_` prefix which will be input as ``**decoder_kwargs`` for the decoder forward function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING)
|
||||||
class EncoderDecoderModel(PreTrainedModel):
|
class EncoderDecoderModel(PreTrainedModel):
|
||||||
r"""
|
r"""
|
||||||
:class:`~transformers.EncoderDecoderModel` is a generic model class that will be
|
:class:`~transformers.EncoderDecoderModel` is a generic model class that will be
|
||||||
@@ -206,6 +272,8 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||||||
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
|
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
|
||||||
return cls(encoder=encoder, decoder=decoder, config=config)
|
return cls(encoder=encoder, decoder=decoder, config=config)
|
||||||
|
|
||||||
|
@add_start_docstrings_to_callable(ENCODER_DECODER_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
@@ -216,47 +284,11 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
decoder_inputs_embeds=None,
|
decoder_inputs_embeds=None,
|
||||||
labels=None,
|
labels=None,
|
||||||
|
return_dict=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
r"""
|
||||||
"""
|
Returns:
|
||||||
Args:
|
|
||||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
|
||||||
Indices of input sequence tokens in the vocabulary for the encoder.
|
|
||||||
Indices can be obtained using :class:`transformers.PretrainedTokenizer`.
|
|
||||||
See :func:`transformers.PreTrainedTokenizer.encode` and
|
|
||||||
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
|
||||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
|
||||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
|
||||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
|
||||||
than the model's internal embedding lookup matrix.
|
|
||||||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
|
||||||
Mask to avoid performing attention on padding token indices for the encoder.
|
|
||||||
Mask values selected in ``[0, 1]``:
|
|
||||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
|
||||||
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`, defaults to :obj:`None`):
|
|
||||||
Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`)
|
|
||||||
`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`) is a sequence of hidden-states at the output of the last layer of the encoder.
|
|
||||||
Used in the cross-attention of the decoder.
|
|
||||||
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
|
|
||||||
Provide for sequence to sequence training to the decoder.
|
|
||||||
Indices can be obtained using :class:`transformers.PretrainedTokenizer`.
|
|
||||||
See :func:`transformers.PreTrainedTokenizer.encode` and
|
|
||||||
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
|
||||||
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
|
|
||||||
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
|
|
||||||
decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
|
||||||
Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation.
|
|
||||||
This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
|
|
||||||
than the model's internal embedding lookup matrix.
|
|
||||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
|
||||||
Labels for computing the masked language modeling loss for the decoder.
|
|
||||||
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
|
||||||
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
|
|
||||||
in ``[0, ..., config.vocab_size]``
|
|
||||||
kwargs: (`optional`) Remaining dictionary of keyword arguments. Keyword arguments come in two flavors:
|
|
||||||
- Without a prefix which will be input as `**encoder_kwargs` for the encoder forward function.
|
|
||||||
- With a `decoder_` prefix which will be input as `**decoder_kwargs` for the decoder forward function.
|
|
||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
|
|
||||||
@@ -264,19 +296,25 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||||||
>>> import torch
|
>>> import torch
|
||||||
|
|
||||||
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||||
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
|
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert from pre-trained checkpoints
|
||||||
|
|
||||||
>>> # forward
|
>>> # forward
|
||||||
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
||||||
>>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)
|
>>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)
|
||||||
|
|
||||||
>>> # training
|
>>> # training
|
||||||
>>> loss, outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=input_ids)[:2]
|
>>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=input_ids, return_dict=True)
|
||||||
|
>>> loss, logits = outputs.loss, outputs.logits
|
||||||
|
|
||||||
|
>>> # save and load from pretrained
|
||||||
|
>>> model.save_pretrained("bert2bert")
|
||||||
|
>>> model = EncoderDecoderModel.from_pretrained("bert2bert")
|
||||||
|
|
||||||
>>> # generation
|
>>> # generation
|
||||||
>>> generated = model.generate(input_ids, decoder_start_token_id=model.config.decoder.pad_token_id)
|
>>> generated = model.generate(input_ids, decoder_start_token_id=model.config.decoder.pad_token_id)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
|
kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
|
||||||
|
|
||||||
@@ -289,7 +327,7 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
return_dict=False,
|
return_dict=return_dict,
|
||||||
**kwargs_encoder,
|
**kwargs_encoder,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -303,23 +341,28 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||||||
encoder_hidden_states=hidden_states,
|
encoder_hidden_states=hidden_states,
|
||||||
encoder_attention_mask=attention_mask,
|
encoder_attention_mask=attention_mask,
|
||||||
labels=labels,
|
labels=labels,
|
||||||
return_dict=False,
|
return_dict=return_dict,
|
||||||
**kwargs_decoder,
|
**kwargs_decoder,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO(PVP): currently it is not possible to use `past`
|
# TODO(PVP): currently it is not possible to use `past`
|
||||||
# with the encoder/decoder framework -> should be implemented
|
if not return_dict:
|
||||||
|
return decoder_outputs + encoder_outputs
|
||||||
|
|
||||||
|
return Seq2SeqLMOutput(
|
||||||
|
loss=decoder_outputs.loss,
|
||||||
|
logits=decoder_outputs.logits,
|
||||||
|
past_key_values=None, # TODO(PVP) - need to implement cache for BERT, etc... before this works
|
||||||
|
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||||
|
decoder_attentions=decoder_outputs.attentions,
|
||||||
|
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
||||||
|
encoder_hidden_states=encoder_outputs.hidden_states,
|
||||||
|
encoder_attentions=encoder_outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
return decoder_outputs + encoder_outputs
|
return decoder_outputs + encoder_outputs
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwargs):
|
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, encoder_outputs, **kwargs):
|
||||||
assert past is not None, "past has to be defined for encoder_outputs"
|
|
||||||
|
|
||||||
# first step
|
|
||||||
if type(past) is tuple:
|
|
||||||
encoder_outputs, _ = past
|
|
||||||
else:
|
|
||||||
encoder_outputs = (past,)
|
|
||||||
|
|
||||||
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids)
|
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids)
|
||||||
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
|
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
|
||||||
input_dict = {
|
input_dict = {
|
||||||
@@ -335,7 +378,7 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||||||
input_dict["decoder_use_cache"] = decoder_inputs["use_cache"]
|
input_dict["decoder_use_cache"] = decoder_inputs["use_cache"]
|
||||||
|
|
||||||
if "past_key_values" in decoder_inputs:
|
if "past_key_values" in decoder_inputs:
|
||||||
input_dict["decoder_past_key_values"] = decoder_inputs["past_key_values"]
|
input_dict["past_key_values"] = decoder_inputs["past_key_values"]
|
||||||
|
|
||||||
return input_dict
|
return input_dict
|
||||||
|
|
||||||
|
|||||||
@@ -353,11 +353,11 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
|
|||||||
Base class for outputs of models predicting if two sentences are consecutive or not.
|
Base class for outputs of models predicting if two sentences are consecutive or not.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
|
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
|
||||||
Language modeling loss.
|
Language modeling loss.
|
||||||
mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided):
|
mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided):
|
||||||
Multiple choice classification loss.
|
Multiple choice classification loss.
|
||||||
lm_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
|
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
|
||||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||||
mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
|
mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
|
||||||
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
|
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
|
||||||
@@ -380,9 +380,9 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
|
|||||||
heads.
|
heads.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
lm_loss: Optional[torch.FloatTensor] = None
|
loss: Optional[torch.FloatTensor] = None
|
||||||
mc_loss: Optional[torch.FloatTensor] = None
|
mc_loss: Optional[torch.FloatTensor] = None
|
||||||
lm_logits: torch.FloatTensor = None
|
logits: torch.FloatTensor = None
|
||||||
mc_logits: torch.FloatTensor = None
|
mc_logits: torch.FloatTensor = None
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None
|
past_key_values: Optional[List[torch.FloatTensor]] = None
|
||||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
@@ -777,6 +777,17 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_head
|
return self.lm_head
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
||||||
|
# only last token for inputs_ids if past is defined in kwargs
|
||||||
|
if past:
|
||||||
|
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"past_key_values": past,
|
||||||
|
"use_cache": kwargs.get("use_cache"),
|
||||||
|
}
|
||||||
|
|
||||||
@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
@@ -893,9 +904,9 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
|||||||
return ((lm_loss,) + output) if lm_loss is not None else output
|
return ((lm_loss,) + output) if lm_loss is not None else output
|
||||||
|
|
||||||
return GPT2DoubleHeadsModelOutput(
|
return GPT2DoubleHeadsModelOutput(
|
||||||
lm_loss=lm_loss,
|
loss=lm_loss,
|
||||||
mc_loss=mc_loss,
|
mc_loss=mc_loss,
|
||||||
lm_logits=lm_logits,
|
logits=lm_logits,
|
||||||
mc_logits=mc_logits,
|
mc_logits=mc_logits,
|
||||||
past_key_values=transformer_outputs.past_key_values,
|
past_key_values=transformer_outputs.past_key_values,
|
||||||
hidden_states=transformer_outputs.hidden_states,
|
hidden_states=transformer_outputs.hidden_states,
|
||||||
|
|||||||
@@ -300,11 +300,11 @@ class OpenAIGPTDoubleHeadsModelOutput(ModelOutput):
|
|||||||
Base class for outputs of models predicting if two sentences are consecutive or not.
|
Base class for outputs of models predicting if two sentences are consecutive or not.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
|
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
|
||||||
Language modeling loss.
|
Language modeling loss.
|
||||||
mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided):
|
mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided):
|
||||||
Multiple choice classification loss.
|
Multiple choice classification loss.
|
||||||
lm_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
|
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
|
||||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||||
mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
|
mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
|
||||||
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
|
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
|
||||||
@@ -321,9 +321,9 @@ class OpenAIGPTDoubleHeadsModelOutput(ModelOutput):
|
|||||||
heads.
|
heads.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
lm_loss: Optional[torch.FloatTensor] = None
|
loss: Optional[torch.FloatTensor] = None
|
||||||
mc_loss: Optional[torch.FloatTensor] = None
|
mc_loss: Optional[torch.FloatTensor] = None
|
||||||
lm_logits: torch.FloatTensor = None
|
logits: torch.FloatTensor = None
|
||||||
mc_logits: torch.FloatTensor = None
|
mc_logits: torch.FloatTensor = None
|
||||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
@@ -713,9 +713,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
|||||||
return ((lm_loss,) + output) if lm_loss is not None else output
|
return ((lm_loss,) + output) if lm_loss is not None else output
|
||||||
|
|
||||||
return OpenAIGPTDoubleHeadsModelOutput(
|
return OpenAIGPTDoubleHeadsModelOutput(
|
||||||
lm_loss=lm_loss,
|
loss=lm_loss,
|
||||||
mc_loss=mc_loss,
|
mc_loss=mc_loss,
|
||||||
lm_logits=lm_logits,
|
logits=lm_logits,
|
||||||
mc_logits=mc_logits,
|
mc_logits=mc_logits,
|
||||||
hidden_states=transformer_outputs.hidden_states,
|
hidden_states=transformer_outputs.hidden_states,
|
||||||
attentions=transformer_outputs.attentions,
|
attentions=transformer_outputs.attentions,
|
||||||
|
|||||||
@@ -109,13 +109,13 @@ class Seq2SeqModelOutput(ModelOutput):
|
|||||||
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||||
Sequence of hidden-states at the output of the last layer of the decoder of the model.
|
Sequence of hidden-states at the output of the last layer of the decoder of the model.
|
||||||
|
|
||||||
If ``decoder_past_key_values`` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
|
If ``past_key_values`` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
|
||||||
decoder_past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||||
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
|
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
|
||||||
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
|
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
|
||||||
|
|
||||||
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
|
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
|
||||||
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
|
used (see ``past_key_values`` input) to speed up sequential decoding.
|
||||||
decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
@@ -143,7 +143,7 @@ class Seq2SeqModelOutput(ModelOutput):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
last_hidden_state: torch.FloatTensor
|
last_hidden_state: torch.FloatTensor
|
||||||
decoder_past_key_values: Optional[List[torch.FloatTensor]] = None
|
past_key_values: Optional[List[torch.FloatTensor]] = None
|
||||||
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
|
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
|
||||||
@@ -255,12 +255,12 @@ class Seq2SeqLMOutput(ModelOutput):
|
|||||||
Language modeling loss.
|
Language modeling loss.
|
||||||
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||||
decoder_past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||||
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
|
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
|
||||||
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
|
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
|
||||||
|
|
||||||
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
|
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
|
||||||
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
|
used (see ``past_key_values`` input) to speed up sequential decoding.
|
||||||
decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
@@ -289,7 +289,7 @@ class Seq2SeqLMOutput(ModelOutput):
|
|||||||
|
|
||||||
loss: Optional[torch.FloatTensor] = None
|
loss: Optional[torch.FloatTensor] = None
|
||||||
logits: torch.FloatTensor = None
|
logits: torch.FloatTensor = None
|
||||||
decoder_past_key_values: Optional[List[torch.FloatTensor]] = None
|
past_key_values: Optional[List[torch.FloatTensor]] = None
|
||||||
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
|
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
|
||||||
@@ -365,12 +365,12 @@ class Seq2SeqSequenceClassifierOutput(ModelOutput):
|
|||||||
Classification (or regression if config.num_labels==1) loss.
|
Classification (or regression if config.num_labels==1) loss.
|
||||||
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
|
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
|
||||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||||
decoder_past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||||
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
|
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
|
||||||
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
|
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
|
||||||
|
|
||||||
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
|
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
|
||||||
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
|
used (see ``past_key_values`` input) to speed up sequential decoding.
|
||||||
decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
@@ -399,7 +399,7 @@ class Seq2SeqSequenceClassifierOutput(ModelOutput):
|
|||||||
|
|
||||||
loss: Optional[torch.FloatTensor] = None
|
loss: Optional[torch.FloatTensor] = None
|
||||||
logits: torch.FloatTensor = None
|
logits: torch.FloatTensor = None
|
||||||
decoder_past_key_values: Optional[List[torch.FloatTensor]] = None
|
past_key_values: Optional[List[torch.FloatTensor]] = None
|
||||||
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
|
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
|
||||||
@@ -511,12 +511,12 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput):
|
|||||||
Span-start scores (before SoftMax).
|
Span-start scores (before SoftMax).
|
||||||
end_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
|
end_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
|
||||||
Span-end scores (before SoftMax).
|
Span-end scores (before SoftMax).
|
||||||
decoder_past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||||
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
|
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
|
||||||
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
|
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
|
||||||
|
|
||||||
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
|
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
|
||||||
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
|
used (see ``past_key_values`` input) to speed up sequential decoding.
|
||||||
decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
@@ -546,7 +546,7 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput):
|
|||||||
loss: Optional[torch.FloatTensor] = None
|
loss: Optional[torch.FloatTensor] = None
|
||||||
start_logits: torch.FloatTensor = None
|
start_logits: torch.FloatTensor = None
|
||||||
end_logits: torch.FloatTensor = None
|
end_logits: torch.FloatTensor = None
|
||||||
decoder_past_key_values: Optional[List[torch.FloatTensor]] = None
|
past_key_values: Optional[List[torch.FloatTensor]] = None
|
||||||
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
|
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
|
||||||
|
|||||||
@@ -838,27 +838,27 @@ T5_INPUTS_DOCSTRING = r"""
|
|||||||
Used in the cross-attention of the decoder.
|
Used in the cross-attention of the decoder.
|
||||||
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
|
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||||
Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation.
|
Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation.
|
||||||
If `decoder_past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `decoder_past_key_values`).
|
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
|
||||||
To know more on how to prepare :obj:`decoder_input_ids` for pre-training take a look at
|
To know more on how to prepare :obj:`decoder_input_ids` for pre-training take a look at
|
||||||
`T5 Training <./t5.html#training>`__. If decoder_input_ids and decoder_inputs_embeds are both None,
|
`T5 Training <./t5.html#training>`__. If decoder_input_ids and decoder_inputs_embeds are both None,
|
||||||
decoder_input_ids takes the value of input_ids.
|
decoder_input_ids takes the value of input_ids.
|
||||||
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
|
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
|
||||||
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
|
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
|
||||||
decoder_past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
Contains pre-computed key and value hidden-states of the attention blocks.
|
Contains pre-computed key and value hidden-states of the attention blocks.
|
||||||
Can be used to speed up decoding.
|
Can be used to speed up decoding.
|
||||||
If `decoder_past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
|
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
|
||||||
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
instead of all `decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
instead of all `decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
If `use_cache` is True, `decoder_past_key_values` are returned and can be used to speed up decoding (see `decoder_past_key_values`).
|
If `use_cache` is True, `past_key_values` are returned and can be used to speed up decoding (see `past_key_values`).
|
||||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
||||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||||
than the model's internal embedding lookup matrix.
|
than the model's internal embedding lookup matrix.
|
||||||
decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
||||||
Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation.
|
Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation.
|
||||||
If `decoder_past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be input (see `decoder_past_key_values`).
|
If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be input (see `past_key_values`).
|
||||||
This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
|
This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
|
||||||
than the model's internal embedding lookup matrix. If decoder_input_ids and decoder_inputs_embeds are both None,
|
than the model's internal embedding lookup matrix. If decoder_input_ids and decoder_inputs_embeds are both None,
|
||||||
decoder_inputs_embeds takes the value of inputs_embeds.
|
decoder_inputs_embeds takes the value of inputs_embeds.
|
||||||
@@ -928,7 +928,7 @@ class T5Model(T5PreTrainedModel):
|
|||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
decoder_past_key_values=None,
|
past_key_values=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
decoder_inputs_embeds=None,
|
decoder_inputs_embeds=None,
|
||||||
@@ -955,10 +955,16 @@ class T5Model(T5PreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
if "decoder_past_key_value_states" in kwargs:
|
if "decoder_past_key_value_states" in kwargs:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.",
|
"The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||||
FutureWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
decoder_past_key_values = kwargs.pop("decoder_past_key_value_states")
|
past_key_values = kwargs.pop("decoder_past_key_value_states")
|
||||||
|
if "decoder_past_key_values" in kwargs:
|
||||||
|
warnings.warn(
|
||||||
|
"The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
past_key_values = kwargs.pop("decoder_past_key_values")
|
||||||
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||||
|
|
||||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
@@ -992,7 +998,7 @@ class T5Model(T5PreTrainedModel):
|
|||||||
|
|
||||||
# If decoding with past key value states, only the last tokens
|
# If decoding with past key value states, only the last tokens
|
||||||
# should be given as an input
|
# should be given as an input
|
||||||
if decoder_past_key_values is not None:
|
if past_key_values is not None:
|
||||||
if decoder_input_ids is not None:
|
if decoder_input_ids is not None:
|
||||||
decoder_input_ids = decoder_input_ids[:, -1:]
|
decoder_input_ids = decoder_input_ids[:, -1:]
|
||||||
if decoder_inputs_embeds is not None:
|
if decoder_inputs_embeds is not None:
|
||||||
@@ -1003,7 +1009,7 @@ class T5Model(T5PreTrainedModel):
|
|||||||
input_ids=decoder_input_ids,
|
input_ids=decoder_input_ids,
|
||||||
attention_mask=decoder_attention_mask,
|
attention_mask=decoder_attention_mask,
|
||||||
inputs_embeds=decoder_inputs_embeds,
|
inputs_embeds=decoder_inputs_embeds,
|
||||||
past_key_value_states=decoder_past_key_values,
|
past_key_value_states=past_key_values,
|
||||||
encoder_hidden_states=hidden_states,
|
encoder_hidden_states=hidden_states,
|
||||||
encoder_attention_mask=attention_mask,
|
encoder_attention_mask=attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
@@ -1013,15 +1019,12 @@ class T5Model(T5PreTrainedModel):
|
|||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
past = (encoder_outputs, decoder_outputs[1]) if use_cache is True else None
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
if past is not None:
|
|
||||||
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
|
|
||||||
return decoder_outputs + encoder_outputs
|
return decoder_outputs + encoder_outputs
|
||||||
|
|
||||||
return Seq2SeqModelOutput(
|
return Seq2SeqModelOutput(
|
||||||
last_hidden_state=decoder_outputs.last_hidden_state,
|
last_hidden_state=decoder_outputs.last_hidden_state,
|
||||||
decoder_past_key_values=past,
|
past_key_values=decoder_outputs.past_key_values,
|
||||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||||
decoder_attentions=decoder_outputs.attentions,
|
decoder_attentions=decoder_outputs.attentions,
|
||||||
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
||||||
@@ -1080,7 +1083,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
|||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
decoder_past_key_values=None,
|
past_key_values=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
labels=None,
|
labels=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1127,10 +1130,16 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
|||||||
labels = kwargs.pop("lm_labels")
|
labels = kwargs.pop("lm_labels")
|
||||||
if "decoder_past_key_value_states" in kwargs:
|
if "decoder_past_key_value_states" in kwargs:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.",
|
"The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||||
FutureWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
decoder_past_key_values = kwargs.pop("decoder_past_key_value_states")
|
past_key_values = kwargs.pop("decoder_past_key_value_states")
|
||||||
|
if "decoder_past_key_values" in kwargs:
|
||||||
|
warnings.warn(
|
||||||
|
"The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
past_key_values = kwargs.pop("decoder_past_key_values")
|
||||||
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||||
|
|
||||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
@@ -1163,7 +1172,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
|||||||
|
|
||||||
# If decoding with past key value states, only the last tokens
|
# If decoding with past key value states, only the last tokens
|
||||||
# should be given as an input
|
# should be given as an input
|
||||||
if decoder_past_key_values is not None:
|
if past_key_values is not None:
|
||||||
assert labels is None, "Decoder should not use cached key value states when training."
|
assert labels is None, "Decoder should not use cached key value states when training."
|
||||||
if decoder_input_ids is not None:
|
if decoder_input_ids is not None:
|
||||||
decoder_input_ids = decoder_input_ids[:, -1:]
|
decoder_input_ids = decoder_input_ids[:, -1:]
|
||||||
@@ -1175,7 +1184,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
|||||||
input_ids=decoder_input_ids,
|
input_ids=decoder_input_ids,
|
||||||
attention_mask=decoder_attention_mask,
|
attention_mask=decoder_attention_mask,
|
||||||
inputs_embeds=decoder_inputs_embeds,
|
inputs_embeds=decoder_inputs_embeds,
|
||||||
past_key_value_states=decoder_past_key_values,
|
past_key_value_states=past_key_values,
|
||||||
encoder_hidden_states=hidden_states,
|
encoder_hidden_states=hidden_states,
|
||||||
encoder_attention_mask=attention_mask,
|
encoder_attention_mask=attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
@@ -1197,17 +1206,14 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
|||||||
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
|
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
|
||||||
# TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
|
# TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
|
||||||
|
|
||||||
past = (encoder_outputs, decoder_outputs[1]) if use_cache is True else None
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
if past is not None:
|
|
||||||
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
|
|
||||||
output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
|
output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
|
||||||
return ((loss,) + output) if loss is not None else output
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
return Seq2SeqLMOutput(
|
return Seq2SeqLMOutput(
|
||||||
loss=loss,
|
loss=loss,
|
||||||
logits=lm_logits,
|
logits=lm_logits,
|
||||||
decoder_past_key_values=past,
|
past_key_values=decoder_outputs.past_key_values,
|
||||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||||
decoder_attentions=decoder_outputs.attentions,
|
decoder_attentions=decoder_outputs.attentions,
|
||||||
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
||||||
@@ -1215,14 +1221,10 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
|||||||
encoder_attentions=encoder_outputs.attentions,
|
encoder_attentions=encoder_outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, use_cache, **kwargs):
|
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs):
|
||||||
assert past is not None, "past has to be defined for encoder_outputs"
|
|
||||||
|
|
||||||
encoder_outputs, decoder_past_key_values = past
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"decoder_input_ids": input_ids,
|
"decoder_input_ids": input_ids,
|
||||||
"decoder_past_key_values": decoder_past_key_values,
|
"past_key_values": past,
|
||||||
"encoder_outputs": encoder_outputs,
|
"encoder_outputs": encoder_outputs,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"use_cache": use_cache,
|
"use_cache": use_cache,
|
||||||
@@ -1231,14 +1233,12 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
|||||||
def _reorder_cache(self, past, beam_idx):
|
def _reorder_cache(self, past, beam_idx):
|
||||||
# if decoder past is not included in output
|
# if decoder past is not included in output
|
||||||
# speedy decoding is disabled and no need to reorder
|
# speedy decoding is disabled and no need to reorder
|
||||||
if past[1] is None:
|
if past is None:
|
||||||
logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
|
logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
|
||||||
return past
|
return past
|
||||||
|
|
||||||
decoder_past = past[1]
|
|
||||||
past = (past[0],)
|
|
||||||
reordered_decoder_past = ()
|
reordered_decoder_past = ()
|
||||||
for layer_past_states in decoder_past:
|
for layer_past_states in past:
|
||||||
# get the correct batch idx from layer past batch dim
|
# get the correct batch idx from layer past batch dim
|
||||||
# batch dim of `past` is at 2nd position
|
# batch dim of `past` is at 2nd position
|
||||||
reordered_layer_past_states = ()
|
reordered_layer_past_states = ()
|
||||||
@@ -1252,4 +1252,4 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
|||||||
assert len(reordered_layer_past_states) == len(layer_past_states)
|
assert len(reordered_layer_past_states) == len(layer_past_states)
|
||||||
|
|
||||||
reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
|
reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
|
||||||
return past + (reordered_decoder_past,)
|
return reordered_decoder_past
|
||||||
|
|||||||
@@ -431,7 +431,7 @@ class TFGPT2DoubleHeadsModelOutput(ModelOutput):
|
|||||||
Base class for outputs of models predicting if two sentences are consecutive or not.
|
Base class for outputs of models predicting if two sentences are consecutive or not.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
lm_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
|
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
|
||||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||||
mc_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`):
|
mc_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`):
|
||||||
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
|
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
|
||||||
@@ -454,7 +454,7 @@ class TFGPT2DoubleHeadsModelOutput(ModelOutput):
|
|||||||
heads.
|
heads.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
lm_logits: tf.Tensor = None
|
logits: tf.Tensor = None
|
||||||
mc_logits: tf.Tensor = None
|
mc_logits: tf.Tensor = None
|
||||||
past_key_values: Optional[List[tf.Tensor]] = None
|
past_key_values: Optional[List[tf.Tensor]] = None
|
||||||
hidden_states: Optional[Tuple[tf.Tensor]] = None
|
hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||||
@@ -794,7 +794,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
|
|||||||
return (lm_logits, mc_logits) + transformer_outputs[1:]
|
return (lm_logits, mc_logits) + transformer_outputs[1:]
|
||||||
|
|
||||||
return TFGPT2DoubleHeadsModelOutput(
|
return TFGPT2DoubleHeadsModelOutput(
|
||||||
lm_logits=lm_logits,
|
logits=lm_logits,
|
||||||
mc_logits=mc_logits,
|
mc_logits=mc_logits,
|
||||||
past_key_values=transformer_outputs.past_key_values,
|
past_key_values=transformer_outputs.past_key_values,
|
||||||
hidden_states=transformer_outputs.hidden_states,
|
hidden_states=transformer_outputs.hidden_states,
|
||||||
|
|||||||
@@ -394,7 +394,7 @@ class TFOpenAIGPTDoubleHeadsModelOutput(ModelOutput):
|
|||||||
Base class for outputs of models predicting if two sentences are consecutive or not.
|
Base class for outputs of models predicting if two sentences are consecutive or not.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
lm_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
|
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
|
||||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||||
mc_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`):
|
mc_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`):
|
||||||
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
|
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
|
||||||
@@ -411,7 +411,7 @@ class TFOpenAIGPTDoubleHeadsModelOutput(ModelOutput):
|
|||||||
heads.
|
heads.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
lm_logits: tf.Tensor = None
|
logits: tf.Tensor = None
|
||||||
mc_logits: tf.Tensor = None
|
mc_logits: tf.Tensor = None
|
||||||
hidden_states: Optional[Tuple[tf.Tensor]] = None
|
hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||||
attentions: Optional[Tuple[tf.Tensor]] = None
|
attentions: Optional[Tuple[tf.Tensor]] = None
|
||||||
@@ -719,7 +719,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
|
|||||||
return (lm_logits, mc_logits) + transformer_outputs[1:]
|
return (lm_logits, mc_logits) + transformer_outputs[1:]
|
||||||
|
|
||||||
return TFOpenAIGPTDoubleHeadsModelOutput(
|
return TFOpenAIGPTDoubleHeadsModelOutput(
|
||||||
lm_logits=lm_logits,
|
logits=lm_logits,
|
||||||
mc_logits=mc_logits,
|
mc_logits=mc_logits,
|
||||||
hidden_states=transformer_outputs.hidden_states,
|
hidden_states=transformer_outputs.hidden_states,
|
||||||
attentions=transformer_outputs.attentions,
|
attentions=transformer_outputs.attentions,
|
||||||
|
|||||||
@@ -113,13 +113,13 @@ class TFSeq2SeqModelOutput(ModelOutput):
|
|||||||
last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||||
Sequence of hidden-states at the output of the last layer of the decoder of the model.
|
Sequence of hidden-states at the output of the last layer of the decoder of the model.
|
||||||
|
|
||||||
If ``decoder_past_key_values`` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
|
If ``past_key_values`` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
|
||||||
decoder_past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||||
List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape
|
List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape
|
||||||
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
|
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
|
||||||
|
|
||||||
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
|
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
|
||||||
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
|
used (see ``past_key_values`` input) to speed up sequential decoding.
|
||||||
decoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
decoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
|
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
|
||||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
@@ -147,7 +147,7 @@ class TFSeq2SeqModelOutput(ModelOutput):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
last_hidden_state: tf.Tensor = None
|
last_hidden_state: tf.Tensor = None
|
||||||
decoder_past_key_values: Optional[List[tf.Tensor]] = None
|
past_key_values: Optional[List[tf.Tensor]] = None
|
||||||
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
|
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||||
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
|
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
|
||||||
encoder_last_hidden_state: Optional[tf.Tensor] = None
|
encoder_last_hidden_state: Optional[tf.Tensor] = None
|
||||||
@@ -259,12 +259,12 @@ class TFSeq2SeqLMOutput(ModelOutput):
|
|||||||
Languaged modeling loss.
|
Languaged modeling loss.
|
||||||
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||||
decoder_past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||||
List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape
|
List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape
|
||||||
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
|
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
|
||||||
|
|
||||||
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
|
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
|
||||||
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
|
used (see ``past_key_values`` input) to speed up sequential decoding.
|
||||||
decoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
decoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
|
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
|
||||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
@@ -293,7 +293,7 @@ class TFSeq2SeqLMOutput(ModelOutput):
|
|||||||
|
|
||||||
loss: Optional[tf.Tensor] = None
|
loss: Optional[tf.Tensor] = None
|
||||||
logits: tf.Tensor = None
|
logits: tf.Tensor = None
|
||||||
decoder_past_key_values: Optional[List[tf.Tensor]] = None
|
past_key_values: Optional[List[tf.Tensor]] = None
|
||||||
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
|
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||||
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
|
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
|
||||||
encoder_last_hidden_state: Optional[tf.Tensor] = None
|
encoder_last_hidden_state: Optional[tf.Tensor] = None
|
||||||
@@ -366,12 +366,12 @@ class TFSeq2SeqSequenceClassifierOutput(ModelOutput):
|
|||||||
Classification (or regression if config.num_labels==1) loss.
|
Classification (or regression if config.num_labels==1) loss.
|
||||||
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, config.num_labels)`):
|
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, config.num_labels)`):
|
||||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||||
decoder_past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||||
List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape
|
List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape
|
||||||
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`.
|
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`.
|
||||||
|
|
||||||
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
|
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
|
||||||
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
|
used (see ``past_key_values`` input) to speed up sequential decoding.
|
||||||
decoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
decoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
|
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
|
||||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
@@ -400,7 +400,7 @@ class TFSeq2SeqSequenceClassifierOutput(ModelOutput):
|
|||||||
|
|
||||||
loss: Optional[tf.Tensor] = None
|
loss: Optional[tf.Tensor] = None
|
||||||
logits: tf.Tensor = None
|
logits: tf.Tensor = None
|
||||||
decoder_past_key_values: Optional[List[tf.Tensor]] = None
|
past_key_values: Optional[List[tf.Tensor]] = None
|
||||||
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
|
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||||
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
|
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
|
||||||
encoder_last_hidden_state: Optional[tf.Tensor] = None
|
encoder_last_hidden_state: Optional[tf.Tensor] = None
|
||||||
@@ -512,12 +512,12 @@ class TFSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
|
|||||||
Span-start scores (before SoftMax).
|
Span-start scores (before SoftMax).
|
||||||
end_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`):
|
end_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`):
|
||||||
Span-end scores (before SoftMax).
|
Span-end scores (before SoftMax).
|
||||||
decoder_past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||||
List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape
|
List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape
|
||||||
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`.
|
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`.
|
||||||
|
|
||||||
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
|
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
|
||||||
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
|
used (see ``past_key_values`` input) to speed up sequential decoding.
|
||||||
decoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
decoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
|
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
|
||||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
@@ -547,7 +547,7 @@ class TFSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
|
|||||||
loss: Optional[tf.Tensor] = None
|
loss: Optional[tf.Tensor] = None
|
||||||
start_logits: tf.Tensor = None
|
start_logits: tf.Tensor = None
|
||||||
end_logits: tf.Tensor = None
|
end_logits: tf.Tensor = None
|
||||||
decoder_past_key_values: Optional[List[tf.Tensor]] = None
|
past_key_values: Optional[List[tf.Tensor]] = None
|
||||||
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
|
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||||
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
|
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
|
||||||
encoder_last_hidden_state: Optional[tf.Tensor] = None
|
encoder_last_hidden_state: Optional[tf.Tensor] = None
|
||||||
|
|||||||
@@ -437,15 +437,15 @@ class TFT5Block(tf.keras.layers.Layer):
|
|||||||
):
|
):
|
||||||
|
|
||||||
if past_key_value_state is not None:
|
if past_key_value_state is not None:
|
||||||
assert self.is_decoder, "Only decoder can use `past_key_value_states`"
|
assert self.is_decoder, "Only decoder can use `past_key_values`"
|
||||||
expected_num_past_key_value_states = 2 if encoder_hidden_states is None else 4
|
expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
|
||||||
|
|
||||||
error_message = "There should be {} past states. 2 (past / key) for self attention.{} Got {} past key / value states".format(
|
error_message = "There should be {} past states. 2 (past / key) for self attention.{} Got {} past key / value states".format(
|
||||||
expected_num_past_key_value_states,
|
expected_num_past_key_values,
|
||||||
"2 (past / key) for cross attention" if expected_num_past_key_value_states == 4 else "",
|
"2 (past / key) for cross attention" if expected_num_past_key_values == 4 else "",
|
||||||
len(past_key_value_state),
|
len(past_key_value_state),
|
||||||
)
|
)
|
||||||
assert len(past_key_value_state) == expected_num_past_key_value_states, error_message
|
assert len(past_key_value_state) == expected_num_past_key_values, error_message
|
||||||
|
|
||||||
self_attn_past_key_value_state = past_key_value_state[:2]
|
self_attn_past_key_value_state = past_key_value_state[:2]
|
||||||
cross_attn_past_key_value_state = past_key_value_state[2:]
|
cross_attn_past_key_value_state = past_key_value_state[2:]
|
||||||
@@ -586,11 +586,12 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
past_key_value_states=None,
|
past_key_values=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
training=False,
|
training=False,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
if isinstance(inputs, (tuple, list)):
|
if isinstance(inputs, (tuple, list)):
|
||||||
input_ids = inputs[0]
|
input_ids = inputs[0]
|
||||||
@@ -599,7 +600,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||||||
encoder_attention_mask = inputs[3] if len(inputs) > 3 else encoder_attention_mask
|
encoder_attention_mask = inputs[3] if len(inputs) > 3 else encoder_attention_mask
|
||||||
inputs_embeds = inputs[4] if len(inputs) > 4 else inputs_embeds
|
inputs_embeds = inputs[4] if len(inputs) > 4 else inputs_embeds
|
||||||
head_mask = inputs[5] if len(inputs) > 5 else head_mask
|
head_mask = inputs[5] if len(inputs) > 5 else head_mask
|
||||||
past_key_value_states = inputs[6] if len(inputs) > 6 else past_key_value_states
|
past_key_values = inputs[6] if len(inputs) > 6 else past_key_values
|
||||||
use_cache = inputs[7] if len(inputs) > 7 else use_cache
|
use_cache = inputs[7] if len(inputs) > 7 else use_cache
|
||||||
output_attentions = inputs[8] if len(inputs) > 8 else output_attentions
|
output_attentions = inputs[8] if len(inputs) > 8 else output_attentions
|
||||||
output_hidden_states = inputs[9] if len(inputs) > 9 else output_hidden_states
|
output_hidden_states = inputs[9] if len(inputs) > 9 else output_hidden_states
|
||||||
@@ -611,13 +612,26 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||||||
encoder_attention_mask = inputs.get("encoder_attention_mask", encoder_attention_mask)
|
encoder_attention_mask = inputs.get("encoder_attention_mask", encoder_attention_mask)
|
||||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||||
head_mask = inputs.get("head_mask", head_mask)
|
head_mask = inputs.get("head_mask", head_mask)
|
||||||
past_key_value_states = inputs.get("past_key_value_states", past_key_value_states)
|
past_key_values = inputs.get("past_key_values", past_key_values)
|
||||||
use_cache = inputs.get("use_cache", use_cache)
|
use_cache = inputs.get("use_cache", use_cache)
|
||||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||||
assert len(inputs) <= 10, "Too many inputs."
|
assert len(inputs) <= 10, "Too many inputs."
|
||||||
|
|
||||||
|
if "past_key_value_states" in inputs:
|
||||||
|
warnings.warn(
|
||||||
|
"The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
past_key_values = inputs.pop("past_key_value_states")
|
||||||
else:
|
else:
|
||||||
input_ids = inputs
|
input_ids = inputs
|
||||||
|
if "past_key_value_states" in kwargs:
|
||||||
|
warnings.warn(
|
||||||
|
"The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
past_key_values = kwargs.pop("past_key_value_states")
|
||||||
|
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
|
||||||
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
|
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
|
||||||
@@ -639,13 +653,13 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
batch_size, seq_length = input_shape
|
batch_size, seq_length = input_shape
|
||||||
|
|
||||||
if past_key_value_states is not None:
|
if past_key_values is not None:
|
||||||
assert seq_length == 1, "Input shape is {}, but should be {} when using past_key_value_sates".format(
|
assert seq_length == 1, "Input shape is {}, but should be {} when using past_key_value_sates".format(
|
||||||
input_shape, (batch_size, 1)
|
input_shape, (batch_size, 1)
|
||||||
)
|
)
|
||||||
# required mask seq length can be calculated via length of past
|
# required mask seq length can be calculated via length of past
|
||||||
# key value states and seq_length = 1 for the last token
|
# key value states and seq_length = 1 for the last token
|
||||||
mask_seq_length = shape_list(past_key_value_states[0][0])[2] + seq_length
|
mask_seq_length = shape_list(past_key_values[0][0])[2] + seq_length
|
||||||
else:
|
else:
|
||||||
mask_seq_length = seq_length
|
mask_seq_length = seq_length
|
||||||
|
|
||||||
@@ -655,9 +669,9 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||||||
encoder_seq_length = shape_list(encoder_hidden_states)[1]
|
encoder_seq_length = shape_list(encoder_hidden_states)[1]
|
||||||
encoder_attention_mask = tf.fill((batch_size, encoder_seq_length), 1)
|
encoder_attention_mask = tf.fill((batch_size, encoder_seq_length), 1)
|
||||||
|
|
||||||
# initialize past_key_value_states with `None` if past does not exist
|
# initialize past_key_values with `None` if past does not exist
|
||||||
if past_key_value_states is None:
|
if past_key_values is None:
|
||||||
past_key_value_states = [None] * len(self.block)
|
past_key_values = [None] * len(self.block)
|
||||||
|
|
||||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||||
@@ -677,7 +691,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
causal_mask = tf.cast(causal_mask, dtype=tf.float32)
|
causal_mask = tf.cast(causal_mask, dtype=tf.float32)
|
||||||
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
||||||
if past_key_value_states[0] is not None:
|
if past_key_values[0] is not None:
|
||||||
extended_attention_mask = extended_attention_mask[:, :, -1:, :]
|
extended_attention_mask = extended_attention_mask[:, :, -1:, :]
|
||||||
else:
|
else:
|
||||||
extended_attention_mask = attention_mask[:, None, None, :]
|
extended_attention_mask = attention_mask[:, None, None, :]
|
||||||
@@ -726,7 +740,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
hidden_states = self.dropout(inputs_embeds, training=training)
|
hidden_states = self.dropout(inputs_embeds, training=training)
|
||||||
|
|
||||||
for i, (layer_module, past_key_value_state) in enumerate(zip(self.block, past_key_value_states)):
|
for i, (layer_module, past_key_value_state) in enumerate(zip(self.block, past_key_values)):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
@@ -878,7 +892,7 @@ T5_INPUTS_DOCSTRING = r"""
|
|||||||
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||||
decoder_input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
|
decoder_input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||||
Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation.
|
Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation.
|
||||||
If `decoder_past_key_value_states` is used, optionally only the last `decoder_input_ids` have to be input (see `decoder_past_key_value_states`).
|
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
|
||||||
attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||||
Mask to avoid performing attention on padding token indices.
|
Mask to avoid performing attention on padding token indices.
|
||||||
Mask values selected in ``[0, 1]``:
|
Mask values selected in ``[0, 1]``:
|
||||||
@@ -889,13 +903,13 @@ T5_INPUTS_DOCSTRING = r"""
|
|||||||
Used in the cross-attention of the decoder.
|
Used in the cross-attention of the decoder.
|
||||||
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
|
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
|
||||||
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
|
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
|
||||||
decoder_past_key_value_states (:obj:`tuple(tuple(tf.Tensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
past_key_values (:obj:`tuple(tuple(tf.Tensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
Contains pre-computed key and value hidden-states of the attention blocks.
|
Contains pre-computed key and value hidden-states of the attention blocks.
|
||||||
Can be used to speed up decoding.
|
Can be used to speed up decoding.
|
||||||
If `decoder_past_key_value_states` are used, the user can optionally input only the last `decoder_input_ids`
|
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
|
||||||
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
If `use_cache` is True, `decoder_past_key_value_states` are returned and can be used to speed up decoding (see `decoder_past_key_value_states`).
|
If `use_cache` is True, `past_key_values` are returned and can be used to speed up decoding (see `past_key_values`).
|
||||||
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
||||||
Optionally, instead of passing :obj:`inputs` you can choose to directly pass an embedded representation.
|
Optionally, instead of passing :obj:`inputs` you can choose to directly pass an embedded representation.
|
||||||
This is useful if you want more control over how to convert `inputs` indices into associated vectors
|
This is useful if you want more control over how to convert `inputs` indices into associated vectors
|
||||||
@@ -969,7 +983,7 @@ class TFT5Model(TFT5PreTrainedModel):
|
|||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
decoder_past_key_value_states=None,
|
past_key_values=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
decoder_inputs_embeds=None,
|
decoder_inputs_embeds=None,
|
||||||
@@ -978,6 +992,7 @@ class TFT5Model(TFT5PreTrainedModel):
|
|||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
training=False,
|
training=False,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Returns:
|
Returns:
|
||||||
@@ -999,7 +1014,7 @@ class TFT5Model(TFT5PreTrainedModel):
|
|||||||
encoder_outputs = inputs[2] if len(inputs) > 2 else encoder_outputs
|
encoder_outputs = inputs[2] if len(inputs) > 2 else encoder_outputs
|
||||||
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
|
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
|
||||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||||
decoder_past_key_value_states = inputs[5] if len(inputs) > 5 else decoder_past_key_value_states
|
past_key_values = inputs[5] if len(inputs) > 5 else past_key_values
|
||||||
decoder_input_ids = inputs[6] if len(inputs) > 6 else decoder_input_ids
|
decoder_input_ids = inputs[6] if len(inputs) > 6 else decoder_input_ids
|
||||||
decoder_attention_mask = inputs[7] if len(inputs) > 7 else decoder_attention_mask
|
decoder_attention_mask = inputs[7] if len(inputs) > 7 else decoder_attention_mask
|
||||||
decoder_inputs_embeds = inputs[8] if len(inputs) > 8 else decoder_inputs_embeds
|
decoder_inputs_embeds = inputs[8] if len(inputs) > 8 else decoder_inputs_embeds
|
||||||
@@ -1017,7 +1032,7 @@ class TFT5Model(TFT5PreTrainedModel):
|
|||||||
encoder_outputs = inputs.get("encoder_outputs", encoder_outputs)
|
encoder_outputs = inputs.get("encoder_outputs", encoder_outputs)
|
||||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||||
head_mask = inputs.get("head_mask", head_mask)
|
head_mask = inputs.get("head_mask", head_mask)
|
||||||
decoder_past_key_value_states = inputs.get("past_key_value_states", decoder_past_key_value_states)
|
past_key_values = inputs.get("past_key_values", past_key_values)
|
||||||
decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids)
|
decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids)
|
||||||
decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask)
|
decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask)
|
||||||
decoder_inputs_embeds = inputs.get("decoder_inputs_embeds", decoder_inputs_embeds)
|
decoder_inputs_embeds = inputs.get("decoder_inputs_embeds", decoder_inputs_embeds)
|
||||||
@@ -1026,9 +1041,23 @@ class TFT5Model(TFT5PreTrainedModel):
|
|||||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||||
return_dict = inputs.get("return_dict", return_dict)
|
return_dict = inputs.get("return_dict", return_dict)
|
||||||
assert len(inputs) <= 13, "Too many inputs."
|
assert len(inputs) <= 13, "Too many inputs."
|
||||||
|
|
||||||
|
if "past_key_value_states" in inputs:
|
||||||
|
warnings.warn(
|
||||||
|
"The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
past_key_values = inputs.pop("past_key_value_states")
|
||||||
else:
|
else:
|
||||||
input_ids = inputs
|
input_ids = inputs
|
||||||
|
|
||||||
|
if "past_key_value_states" in kwargs:
|
||||||
|
warnings.warn(
|
||||||
|
"The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
past_key_values = kwargs.pop("past_key_value_states")
|
||||||
|
|
||||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||||
|
|
||||||
@@ -1054,7 +1083,7 @@ class TFT5Model(TFT5PreTrainedModel):
|
|||||||
|
|
||||||
# If decoding with past key value states, only the last tokens
|
# If decoding with past key value states, only the last tokens
|
||||||
# should be given as an input
|
# should be given as an input
|
||||||
if decoder_past_key_value_states is not None:
|
if past_key_values is not None:
|
||||||
if decoder_input_ids is not None:
|
if decoder_input_ids is not None:
|
||||||
decoder_input_ids = decoder_input_ids[:, -1:]
|
decoder_input_ids = decoder_input_ids[:, -1:]
|
||||||
if decoder_inputs_embeds is not None:
|
if decoder_inputs_embeds is not None:
|
||||||
@@ -1069,7 +1098,7 @@ class TFT5Model(TFT5PreTrainedModel):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
decoder_inputs_embeds,
|
decoder_inputs_embeds,
|
||||||
head_mask,
|
head_mask,
|
||||||
decoder_past_key_value_states,
|
past_key_values,
|
||||||
use_cache,
|
use_cache,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
output_hidden_states,
|
output_hidden_states,
|
||||||
@@ -1103,7 +1132,7 @@ class TFT5Model(TFT5PreTrainedModel):
|
|||||||
|
|
||||||
return TFSeq2SeqModelOutput(
|
return TFSeq2SeqModelOutput(
|
||||||
last_hidden_state=decoder_outputs[0],
|
last_hidden_state=decoder_outputs[0],
|
||||||
decoder_past_key_values=past,
|
past_key_values=past,
|
||||||
decoder_hidden_states=decoder_outputs[2],
|
decoder_hidden_states=decoder_outputs[2],
|
||||||
decoder_attentions=decoder_outputs[3],
|
decoder_attentions=decoder_outputs[3],
|
||||||
encoder_last_hidden_state=encoder_outputs[0],
|
encoder_last_hidden_state=encoder_outputs[0],
|
||||||
@@ -1164,7 +1193,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
|||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
decoder_past_key_value_states=None,
|
past_key_values=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
decoder_inputs_embeds=None,
|
decoder_inputs_embeds=None,
|
||||||
@@ -1174,6 +1203,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
|||||||
return_dict=None,
|
return_dict=None,
|
||||||
labels=None,
|
labels=None,
|
||||||
training=False,
|
training=False,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||||
@@ -1204,7 +1234,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
|||||||
encoder_outputs = inputs[2] if len(inputs) > 2 else encoder_outputs
|
encoder_outputs = inputs[2] if len(inputs) > 2 else encoder_outputs
|
||||||
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
|
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
|
||||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||||
decoder_past_key_value_states = inputs[5] if len(inputs) > 5 else decoder_past_key_value_states
|
past_key_values = inputs[5] if len(inputs) > 5 else past_key_values
|
||||||
decoder_input_ids = inputs[6] if len(inputs) > 6 else decoder_input_ids
|
decoder_input_ids = inputs[6] if len(inputs) > 6 else decoder_input_ids
|
||||||
decoder_attention_mask = inputs[7] if len(inputs) > 7 else decoder_attention_mask
|
decoder_attention_mask = inputs[7] if len(inputs) > 7 else decoder_attention_mask
|
||||||
decoder_inputs_embeds = inputs[8] if len(inputs) > 8 else decoder_inputs_embeds
|
decoder_inputs_embeds = inputs[8] if len(inputs) > 8 else decoder_inputs_embeds
|
||||||
@@ -1223,7 +1253,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
|||||||
encoder_outputs = inputs.get("encoder_outputs", encoder_outputs)
|
encoder_outputs = inputs.get("encoder_outputs", encoder_outputs)
|
||||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||||
head_mask = inputs.get("head_mask", head_mask)
|
head_mask = inputs.get("head_mask", head_mask)
|
||||||
decoder_past_key_value_states = inputs.get("past_key_value_states", decoder_past_key_value_states)
|
past_key_values = inputs.get("past_key_values", past_key_values)
|
||||||
decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids)
|
decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids)
|
||||||
decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask)
|
decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask)
|
||||||
decoder_inputs_embeds = inputs.get("decoder_inputs_embeds", decoder_inputs_embeds)
|
decoder_inputs_embeds = inputs.get("decoder_inputs_embeds", decoder_inputs_embeds)
|
||||||
@@ -1233,9 +1263,23 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
|||||||
return_dict = inputs.get("return_dict", return_dict)
|
return_dict = inputs.get("return_dict", return_dict)
|
||||||
labels = inputs.get("labels", labels)
|
labels = inputs.get("labels", labels)
|
||||||
assert len(inputs) <= 14, "Too many inputs."
|
assert len(inputs) <= 14, "Too many inputs."
|
||||||
|
|
||||||
|
if "past_key_value_states" in inputs:
|
||||||
|
warnings.warn(
|
||||||
|
"The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
past_key_values = inputs.pop("past_key_value_states")
|
||||||
else:
|
else:
|
||||||
input_ids = inputs
|
input_ids = inputs
|
||||||
|
|
||||||
|
if "past_key_value_states" in kwargs:
|
||||||
|
warnings.warn(
|
||||||
|
"The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
past_key_values = kwargs.pop("past_key_value_states")
|
||||||
|
|
||||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||||
|
|
||||||
@@ -1266,7 +1310,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
|||||||
|
|
||||||
# If decoding with past key value states, only the last tokens
|
# If decoding with past key value states, only the last tokens
|
||||||
# should be given as an input
|
# should be given as an input
|
||||||
if decoder_past_key_value_states is not None:
|
if past_key_values is not None:
|
||||||
if decoder_input_ids is not None:
|
if decoder_input_ids is not None:
|
||||||
decoder_input_ids = decoder_input_ids[:, -1:]
|
decoder_input_ids = decoder_input_ids[:, -1:]
|
||||||
if decoder_inputs_embeds is not None:
|
if decoder_inputs_embeds is not None:
|
||||||
@@ -1281,7 +1325,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
decoder_inputs_embeds,
|
decoder_inputs_embeds,
|
||||||
head_mask,
|
head_mask,
|
||||||
decoder_past_key_value_states,
|
past_key_values,
|
||||||
use_cache,
|
use_cache,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
output_hidden_states,
|
output_hidden_states,
|
||||||
@@ -1324,7 +1368,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
|||||||
return TFSeq2SeqLMOutput(
|
return TFSeq2SeqLMOutput(
|
||||||
loss=loss,
|
loss=loss,
|
||||||
logits=logits,
|
logits=logits,
|
||||||
decoder_past_key_values=past,
|
past_key_values=past,
|
||||||
decoder_hidden_states=decoder_outputs[2],
|
decoder_hidden_states=decoder_outputs[2],
|
||||||
decoder_attentions=decoder_outputs[3],
|
decoder_attentions=decoder_outputs[3],
|
||||||
encoder_last_hidden_state=encoder_outputs[0],
|
encoder_last_hidden_state=encoder_outputs[0],
|
||||||
@@ -1337,14 +1381,14 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
|||||||
|
|
||||||
# first step
|
# first step
|
||||||
if len(past) < 2:
|
if len(past) < 2:
|
||||||
encoder_outputs, decoder_past_key_value_states = past, None
|
encoder_outputs, past_key_values = past, None
|
||||||
else:
|
else:
|
||||||
encoder_outputs, decoder_past_key_value_states = past[0], past[1]
|
encoder_outputs, past_key_values = past[0], past[1]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"inputs": None, # inputs don't have to be defined, but still need to be passed to make Keras.layer.__call__ happy
|
"inputs": None, # inputs don't have to be defined, but still need to be passed to make Keras.layer.__call__ happy
|
||||||
"decoder_input_ids": inputs, # inputs are the decoder_input_ids
|
"decoder_input_ids": inputs, # inputs are the decoder_input_ids
|
||||||
"decoder_past_key_value_states": decoder_past_key_value_states,
|
"past_key_values": past_key_values,
|
||||||
"encoder_outputs": encoder_outputs,
|
"encoder_outputs": encoder_outputs,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"use_cache": use_cache,
|
"use_cache": use_cache,
|
||||||
|
|||||||
@@ -661,6 +661,15 @@ class TransfoXLLMHeadModelOutput(ModelOutput):
|
|||||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def logits(self):
|
||||||
|
# prediction scores are the output of the adaptive softmax, see
|
||||||
|
# the file `modeling_transfo_xl_utilities`. Since the adaptive
|
||||||
|
# softmax returns the log softmax value, `self.prediction_scores`
|
||||||
|
# are strictly speaking not exactly `logits`, but behave the same
|
||||||
|
# way logits do.
|
||||||
|
return self.prediction_scores
|
||||||
|
|
||||||
|
|
||||||
TRANSFO_XL_START_DOCSTRING = r"""
|
TRANSFO_XL_START_DOCSTRING = r"""
|
||||||
|
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ if is_torch_available():
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
BertLMHeadModel,
|
BertLMHeadModel,
|
||||||
BertModel,
|
BertModel,
|
||||||
|
BertTokenizer,
|
||||||
EncoderDecoderConfig,
|
EncoderDecoderConfig,
|
||||||
EncoderDecoderModel,
|
EncoderDecoderModel,
|
||||||
GPT2LMHeadModel,
|
GPT2LMHeadModel,
|
||||||
@@ -128,10 +129,11 @@ class EncoderDecoderMixin:
|
|||||||
decoder_config,
|
decoder_config,
|
||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
decoder_attention_mask,
|
decoder_attention_mask,
|
||||||
|
return_dict,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||||
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model}
|
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model, "return_dict": return_dict}
|
||||||
enc_dec_model = EncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
|
enc_dec_model = EncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
|
||||||
enc_dec_model.to(torch_device)
|
enc_dec_model.to(torch_device)
|
||||||
outputs_encoder_decoder = enc_dec_model(
|
outputs_encoder_decoder = enc_dec_model(
|
||||||
@@ -361,7 +363,11 @@ class EncoderDecoderMixin:
|
|||||||
|
|
||||||
def test_encoder_decoder_model_from_pretrained(self):
|
def test_encoder_decoder_model_from_pretrained(self):
|
||||||
input_ids_dict = self.prepare_config_and_inputs()
|
input_ids_dict = self.prepare_config_and_inputs()
|
||||||
self.check_encoder_decoder_model_from_pretrained(**input_ids_dict)
|
self.check_encoder_decoder_model_from_pretrained(**input_ids_dict, return_dict=False)
|
||||||
|
|
||||||
|
def test_encoder_decoder_model_from_pretrained_return_dict(self):
|
||||||
|
input_ids_dict = self.prepare_config_and_inputs()
|
||||||
|
self.check_encoder_decoder_model_from_pretrained(**input_ids_dict, return_dict=True)
|
||||||
|
|
||||||
def test_save_and_load_from_pretrained(self):
|
def test_save_and_load_from_pretrained(self):
|
||||||
input_ids_dict = self.prepare_config_and_inputs()
|
input_ids_dict = self.prepare_config_and_inputs()
|
||||||
@@ -466,6 +472,22 @@ class BertEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
|
|||||||
"labels": decoder_token_labels,
|
"labels": decoder_token_labels,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_bert2bert_summarization(self):
|
||||||
|
model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
|
||||||
|
model.to(torch_device)
|
||||||
|
tokenizer = BertTokenizer.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
|
||||||
|
|
||||||
|
ARTICLE = """(CNN)Sigma Alpha Epsilon is under fire for a video showing party-bound fraternity members singing a racist chant. SAE's national chapter suspended the students, but University of Oklahoma President David Boren took it a step further, saying the university's affiliation with the fraternity is permanently done. The news is shocking, but it's not the first time SAE has faced controversy. SAE was founded March 9, 1856, at the University of Alabama, five years before the American Civil War, according to the fraternity website. When the war began, the group had fewer than 400 members, of which "369 went to war for the Confederate States and seven for the Union Army," the website says. The fraternity now boasts more than 200,000 living alumni, along with about 15,000 undergraduates populating 219 chapters and 20 "colonies" seeking full membership at universities. SAE has had to work hard to change recently after a string of member deaths, many blamed on the hazing of new recruits, SAE national President Bradley Cohen wrote in a message on the fraternity's website. The fraternity's website lists more than 130 chapters cited or suspended for "health and safety incidents" since 2010. At least 30 of the incidents involved hazing, and dozens more involved alcohol. However, the list is missing numerous incidents from recent months. Among them, according to various media outlets: Yale University banned the SAEs from campus activities last month after members allegedly tried to interfere with a sexual misconduct investigation connected to an initiation rite. Stanford University in December suspended SAE housing privileges after finding sorority members attending a fraternity function were subjected to graphic sexual content. And Johns Hopkins University in November suspended the fraternity for underage drinking. "The media has labeled us as the 'nation's deadliest fraternity,' " Cohen said. 
In 2011, for example, a student died while being coerced into excessive alcohol consumption, according to a lawsuit. SAE's previous insurer dumped the fraternity. "As a result, we are paying Lloyd's of London the highest insurance rates in the Greek-letter world," Cohen said. Universities have turned down SAE's attempts to open new chapters, and the fraternity had to close 12 in 18 months over hazing incidents."""
|
||||||
|
|
||||||
|
EXPECTED_SUMMARY = """sae was founded in 1856, five years before the civil war. the fraternity has had to work hard to change recently. the university of oklahoma president says the university's affiliation with the fraternity is permanently done. the sae has had a string of members in recent months."""
|
||||||
|
|
||||||
|
input_ids = tokenizer(ARTICLE, return_tensors="pt").input_ids.to(torch_device)
|
||||||
|
output_ids = model.generate(input_ids)
|
||||||
|
summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
||||||
|
|
||||||
|
self.assertEqual(summary, EXPECTED_SUMMARY)
|
||||||
|
|
||||||
|
|
||||||
class RoBertaEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
|
class RoBertaEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||||
def get_encoder_decoder_model(self, config, decoder_config):
|
def get_encoder_decoder_model(self, config, decoder_config):
|
||||||
|
|||||||
@@ -289,9 +289,9 @@ class GPT2ModelTester:
|
|||||||
}
|
}
|
||||||
|
|
||||||
result = model(**inputs)
|
result = model(**inputs)
|
||||||
self.parent.assertEqual(result.lm_loss.shape, ())
|
self.parent.assertEqual(result.loss.shape, ())
|
||||||
self.parent.assertEqual(
|
self.parent.assertEqual(
|
||||||
result.lm_logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size)
|
result.logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size)
|
||||||
)
|
)
|
||||||
self.parent.assertEqual(result.mc_logits.shape, (self.batch_size, self.num_choices))
|
self.parent.assertEqual(result.mc_logits.shape, (self.batch_size, self.num_choices))
|
||||||
|
|
||||||
@@ -324,7 +324,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
all_model_classes = (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
|
all_model_classes = (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
|
||||||
all_generative_model_classes = (
|
all_generative_model_classes = (
|
||||||
(GPT2LMHeadModel,) if is_torch_available() else ()
|
(GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
|
||||||
) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly
|
) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly
|
||||||
test_missing_keys = False
|
test_missing_keys = False
|
||||||
|
|
||||||
|
|||||||
@@ -131,8 +131,8 @@ class OpenAIGPTModelTester:
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
|
result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
|
||||||
self.parent.assertEqual(result.lm_loss.shape, ())
|
self.parent.assertEqual(result.loss.shape, ())
|
||||||
self.parent.assertEqual(result.lm_logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
|
|||||||
@@ -159,17 +159,15 @@ class T5ModelTester:
|
|||||||
)
|
)
|
||||||
result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
||||||
decoder_output = result.last_hidden_state
|
decoder_output = result.last_hidden_state
|
||||||
decoder_past = result.decoder_past_key_values
|
decoder_past = result.past_key_values
|
||||||
encoder_output = result.encoder_last_hidden_state
|
encoder_output = result.encoder_last_hidden_state
|
||||||
|
|
||||||
self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size))
|
self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size))
|
||||||
self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.decoder_seq_length, self.hidden_size))
|
self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.decoder_seq_length, self.hidden_size))
|
||||||
self.parent.assertEqual(len(decoder_past), 2)
|
# There should be `num_layers` key value embeddings stored in decoder_past
|
||||||
self.parent.assertTrue(torch.all(decoder_past[0][0] == encoder_output))
|
self.parent.assertEqual(len(decoder_past), config.num_layers)
|
||||||
# There should be `num_layers` key value embeddings stored in decoder_past[1]
|
# There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past tuple
|
||||||
self.parent.assertEqual(len(decoder_past[1]), config.num_layers)
|
self.parent.assertEqual(len(decoder_past[0]), 4)
|
||||||
# There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past[1] tuple
|
|
||||||
self.parent.assertEqual(len(decoder_past[1][0]), 4)
|
|
||||||
|
|
||||||
def create_and_check_with_lm_head(
|
def create_and_check_with_lm_head(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -238,7 +238,7 @@ class TFGPT2ModelTester:
|
|||||||
}
|
}
|
||||||
result = model(inputs)
|
result = model(inputs)
|
||||||
self.parent.assertEqual(
|
self.parent.assertEqual(
|
||||||
result.lm_logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size)
|
result.logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size)
|
||||||
)
|
)
|
||||||
self.parent.assertEqual(result.mc_logits.shape, (self.batch_size, self.num_choices))
|
self.parent.assertEqual(result.mc_logits.shape, (self.batch_size, self.num_choices))
|
||||||
|
|
||||||
|
|||||||
@@ -151,7 +151,7 @@ class TFOpenAIGPTModelTester:
|
|||||||
}
|
}
|
||||||
result = model(inputs)
|
result = model(inputs)
|
||||||
self.parent.assertEqual(
|
self.parent.assertEqual(
|
||||||
result.lm_logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size)
|
result.logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size)
|
||||||
)
|
)
|
||||||
self.parent.assertEqual(result.mc_logits.shape, (self.batch_size, self.num_choices))
|
self.parent.assertEqual(result.mc_logits.shape, (self.batch_size, self.num_choices))
|
||||||
|
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ class TFT5ModelTester:
|
|||||||
|
|
||||||
result = model(input_ids, decoder_attention_mask=input_mask, decoder_input_ids=input_ids)
|
result = model(input_ids, decoder_attention_mask=input_mask, decoder_input_ids=input_ids)
|
||||||
decoder_output = result.last_hidden_state
|
decoder_output = result.last_hidden_state
|
||||||
decoder_past = result.decoder_past_key_values
|
decoder_past = result.past_key_values
|
||||||
encoder_output = result.encoder_last_hidden_state
|
encoder_output = result.encoder_last_hidden_state
|
||||||
self.parent.assertListEqual(list(encoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])
|
self.parent.assertListEqual(list(encoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])
|
||||||
self.parent.assertListEqual(list(decoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])
|
self.parent.assertListEqual(list(decoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])
|
||||||
|
|||||||
Reference in New Issue
Block a user