[Config, Caching] Remove output_past everywhere and replace by use_cache argument (#3734)
* remove output_past from pt * make style * add optional input length for gpt2 * add use cache to prepare input * save memory in gpt2 * correct gpt2 test inputs * make past input optional for gpt2 * finish use_cache for all models * make style * delete modeling_gpt2 change in test file * correct docstring * correct is true statements for gpt2
This commit is contained in:
committed by
GitHub
parent
092cf881a5
commit
01c37dcdb5
@@ -59,7 +59,7 @@ class PretrainedConfig(object):
|
||||
# Attributes with defaults
|
||||
self.output_attentions = kwargs.pop("output_attentions", False)
|
||||
self.output_hidden_states = kwargs.pop("output_hidden_states", False)
|
||||
self.output_past = kwargs.pop("output_past", True) # Not used by all models
|
||||
self.use_cache = kwargs.pop("use_cache", True) # Not used by all models
|
||||
self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
|
||||
self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
|
||||
self.pruned_heads = kwargs.pop("pruned_heads", {})
|
||||
|
||||
@@ -933,7 +933,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
|
||||
return outputs
|
||||
|
||||
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, **kwargs):
|
||||
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs):
|
||||
assert past is not None, "past has to be defined for encoder_outputs"
|
||||
|
||||
# first step, decoder_cached_states are empty
|
||||
@@ -947,7 +947,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
"decoder_cached_states": decoder_cached_states,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"use_cache": True, # change this to avoid caching (presumably for debugging)
|
||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||
}
|
||||
|
||||
def prepare_scores_for_generation(self, scores, cur_len, max_length):
|
||||
@@ -980,10 +980,6 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
def get_output_embeddings(self):
|
||||
return _make_linear_from_emb(self.model.shared) # make it on the fly
|
||||
|
||||
def _do_output_past(self, *args, **kwargs):
|
||||
""" We should always use the cache in generate."""
|
||||
return True
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """,
|
||||
|
||||
@@ -98,7 +98,7 @@ class MultiHeadAttention(torch.nn.Module):
|
||||
x = x.reshape(batch_size, -1, self.num_heads, self.depth)
|
||||
return x.permute([0, 2, 1, 3])
|
||||
|
||||
def forward(self, v, k, q, mask, layer_past=None, attention_mask=None, head_mask=None):
|
||||
def forward(self, v, k, q, mask, layer_past=None, attention_mask=None, head_mask=None, use_cache=False):
|
||||
batch_size = q.shape[0]
|
||||
|
||||
q = self.Wq(q)
|
||||
@@ -112,7 +112,11 @@ class MultiHeadAttention(torch.nn.Module):
|
||||
past_key, past_value = layer_past[0], layer_past[1]
|
||||
k = torch.cat((past_key, k), dim=-2)
|
||||
v = torch.cat((past_value, v), dim=-2)
|
||||
present = torch.stack((k, v))
|
||||
|
||||
if use_cache is True:
|
||||
present = torch.stack((k, v))
|
||||
else:
|
||||
present = (None,)
|
||||
|
||||
output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
|
||||
scaled_attention = output[0].permute([0, 2, 1, 3])
|
||||
@@ -143,10 +147,17 @@ class EncoderLayer(torch.nn.Module):
|
||||
self.dropout1 = torch.nn.Dropout(rate)
|
||||
self.dropout2 = torch.nn.Dropout(rate)
|
||||
|
||||
def forward(self, x, mask, layer_past=None, attention_mask=None, head_mask=None):
|
||||
def forward(self, x, mask, layer_past=None, attention_mask=None, head_mask=None, use_cache=False):
|
||||
normed = self.layernorm1(x)
|
||||
attn_outputs = self.multi_head_attention(
|
||||
normed, normed, normed, mask, layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask
|
||||
normed,
|
||||
normed,
|
||||
normed,
|
||||
mask,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
attn_output = attn_outputs[0]
|
||||
attn_output = self.dropout1(attn_output)
|
||||
@@ -199,6 +210,7 @@ CTRL_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
If `past` is used, optionally only the last `input_ids` have to be input (see `past`).
|
||||
|
||||
Indices can be obtained using :class:`transformers.CTRLTokenizer`.
|
||||
See :func:`transformers.PreTrainedTokenizer.encode` and
|
||||
@@ -207,8 +219,10 @@ CTRL_INPUTS_DOCSTRING = r"""
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
||||
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
(see `past` output below). Can be used to speed up sequential decoding. The token ids which have their past given to this model
|
||||
should not be passed as input ids as they have already been computed.
|
||||
(see `past` output below). Can be used to speed up sequential decoding.
|
||||
If `past` is used, the user can optionally input only the last `input_ids`
|
||||
(those that don't have their past given to this model) of shape :obj:`(batch_size, 1)`
|
||||
instead of all `input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
@@ -219,6 +233,7 @@ CTRL_INPUTS_DOCSTRING = r"""
|
||||
Segment token indices to indicate first and second portions of the inputs.
|
||||
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
|
||||
corresponds to a `sentence B` token
|
||||
If `past` is used, optionally only the last `token_type_ids` have to be input (see `past`).
|
||||
|
||||
`What are token type IDs? <../glossary.html#token-type-ids>`_
|
||||
position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
@@ -234,6 +249,10 @@ CTRL_INPUTS_DOCSTRING = r"""
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
If `past` is used, optionally only the last `input_embeds` have to be input (see `past`).
|
||||
use_cache (:obj:`bool`):
|
||||
If `use_cache` is True, `past` key value states are returned and
|
||||
can be used to speed up decoding (see `past`). Defaults to `True`.
|
||||
"""
|
||||
|
||||
|
||||
@@ -246,7 +265,6 @@ class CTRLModel(CTRLPreTrainedModel):
|
||||
super().__init__(config)
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.output_attentions = config.output_attentions
|
||||
self.output_past = config.output_past
|
||||
|
||||
self.d_model_size = config.n_embd
|
||||
self.num_layers = config.n_layer
|
||||
@@ -289,6 +307,7 @@ class CTRLModel(CTRLPreTrainedModel):
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=True,
|
||||
):
|
||||
r"""
|
||||
Return:
|
||||
@@ -297,8 +316,7 @@ class CTRLModel(CTRLPreTrainedModel):
|
||||
Sequence of hidden-states at the last layer of the model.
|
||||
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
|
||||
Contains pre-computed hidden-states (key and values in the attention blocks).
|
||||
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
|
||||
should not be passed as input ids as they have already been computed.
|
||||
Can be used (see `past` input) to speed up sequential decoding.
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
@@ -325,6 +343,17 @@ class CTRLModel(CTRLPreTrainedModel):
|
||||
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
||||
|
||||
"""
|
||||
|
||||
# If using past key value states, only the last tokens
|
||||
# should be given as an input
|
||||
if past is not None:
|
||||
if input_ids is not None:
|
||||
input_ids = input_ids[:, -1:]
|
||||
if inputs_embeds is not None:
|
||||
inputs_embeds = inputs_embeds[:, -1:]
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids[:, -1:]
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
@@ -414,10 +443,15 @@ class CTRLModel(CTRLPreTrainedModel):
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
|
||||
outputs = h(
|
||||
hidden_states, mask, layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask[i]
|
||||
hidden_states,
|
||||
mask,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask[i],
|
||||
use_cache=use_cache,
|
||||
)
|
||||
hidden_states, present = outputs[:2]
|
||||
if self.output_past:
|
||||
if use_cache is True:
|
||||
presents = presents + (present,)
|
||||
|
||||
if self.output_attentions:
|
||||
@@ -429,7 +463,7 @@ class CTRLModel(CTRLPreTrainedModel):
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if self.output_past:
|
||||
if use_cache is True:
|
||||
outputs = outputs + (presents,)
|
||||
if self.output_hidden_states:
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
@@ -462,7 +496,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
||||
if past:
|
||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
return {"input_ids": input_ids, "past": past}
|
||||
return {"input_ids": input_ids, "past": past, "use_cache": kwargs["use_cache"]}
|
||||
|
||||
@add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
@@ -475,6 +509,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
use_cache=True,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
@@ -492,8 +527,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
|
||||
Contains pre-computed hidden-states (key and values in the attention blocks).
|
||||
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
|
||||
should not be passed as input ids as they have already been computed.
|
||||
Can be used (see `past` input) to speed up sequential decoding.
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
@@ -527,6 +561,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
||||
@@ -177,7 +177,7 @@ class Attention(nn.Module):
|
||||
else:
|
||||
return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
|
||||
|
||||
def forward(self, x, layer_past=None, attention_mask=None, head_mask=None):
|
||||
def forward(self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False):
|
||||
x = self.c_attn(x)
|
||||
query, key, value = x.split(self.split_size, dim=2)
|
||||
query = self.split_heads(query)
|
||||
@@ -187,7 +187,11 @@ class Attention(nn.Module):
|
||||
past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below
|
||||
key = torch.cat((past_key, key), dim=-1)
|
||||
value = torch.cat((past_value, value), dim=-2)
|
||||
present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
|
||||
|
||||
if use_cache is True:
|
||||
present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
|
||||
else:
|
||||
present = (None,)
|
||||
|
||||
attn_outputs = self._attn(query, key, value, attention_mask, head_mask)
|
||||
a = attn_outputs[0]
|
||||
@@ -224,9 +228,13 @@ class Block(nn.Module):
|
||||
self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
|
||||
self.mlp = MLP(4 * nx, config)
|
||||
|
||||
def forward(self, x, layer_past=None, attention_mask=None, head_mask=None):
|
||||
def forward(self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False):
|
||||
output_attn = self.attn(
|
||||
self.ln_1(x), layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask
|
||||
self.ln_1(x),
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
a = output_attn[0] # output_attn: a, present, (attentions)
|
||||
|
||||
@@ -279,10 +287,9 @@ GPT2_START_DOCSTRING = r"""
|
||||
|
||||
GPT2_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
|
||||
`input_ids_length` = `sequence_length if `past` is None else 1
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
If using `past` as an input make sure that `input_ids` are those of the last position.
|
||||
If `past` is used, optionally only the last `input_ids` have to be input (see `past`).
|
||||
|
||||
Indices can be obtained using :class:`transformers.GPT2Tokenizer`.
|
||||
See :func:`transformers.PreTrainedTokenizer.encode` and
|
||||
@@ -292,8 +299,8 @@ GPT2_INPUTS_DOCSTRING = r"""
|
||||
|
||||
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
||||
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
(see `past` output below). Can be used to speed up sequential decoding. The token ids which have their past given to this model
|
||||
should not be passed as input ids as they have already been computed.
|
||||
(see `past` output below). Can be used to speed up sequential decoding.
|
||||
If `past` is used, the user can optionally input only the last `input_ids` (those that don't have their past given to this model) of shape :obj:`(batch_size, 1)` instead of all `input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
@@ -305,7 +312,7 @@ GPT2_INPUTS_DOCSTRING = r"""
|
||||
Segment token indices to indicate first and second portions of the inputs.
|
||||
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
|
||||
corresponds to a `sentence B` token
|
||||
If using `past` as an input make sure that `token_type_ids` correspond to the `input_ids` of the last position.
|
||||
If `past` is used, optionally only the last `token_type_ids` have to be input (see `past`).
|
||||
|
||||
`What are token type IDs? <../glossary.html#token-type-ids>`_
|
||||
position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
@@ -321,6 +328,9 @@ GPT2_INPUTS_DOCSTRING = r"""
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
If `past` is used, optionally only the last `input_embeds` have to be input (see `past`).
|
||||
use_cache (:obj:`bool`):
|
||||
If `use_cache` is True, `past` key value states are returned and can be used to speed up decoding (see `past`). Defaults to `True`.
|
||||
"""
|
||||
|
||||
|
||||
@@ -333,7 +343,6 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
super().__init__(config)
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.output_attentions = config.output_attentions
|
||||
self.output_past = config.output_past
|
||||
|
||||
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
|
||||
self.wpe = nn.Embedding(config.n_positions, config.n_embd)
|
||||
@@ -366,16 +375,17 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=True,
|
||||
):
|
||||
r"""
|
||||
Return:
|
||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
|
||||
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the last layer of the model.
|
||||
If `past` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
|
||||
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
|
||||
Contains pre-computed hidden-states (key and values in the attention blocks).
|
||||
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
|
||||
should not be passed as input ids as they have already been computed.
|
||||
Can be used (see `past` input) to speed up sequential decoding.
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
@@ -400,6 +410,17 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
||||
|
||||
"""
|
||||
|
||||
# If using past key value states, only the last tokens
|
||||
# should be given as an input
|
||||
if past is not None:
|
||||
if input_ids is not None:
|
||||
input_ids = input_ids[:, -1:]
|
||||
if inputs_embeds is not None:
|
||||
inputs_embeds = inputs_embeds[:, -1:]
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids[:, -1:]
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
@@ -484,11 +505,15 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
|
||||
|
||||
outputs = block(
|
||||
hidden_states, layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask[i]
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask[i],
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states, present = outputs[:2]
|
||||
if self.output_past:
|
||||
if use_cache is True:
|
||||
presents = presents + (present,)
|
||||
|
||||
if self.output_attentions:
|
||||
@@ -502,7 +527,7 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if self.output_past:
|
||||
if use_cache is True:
|
||||
outputs = outputs + (presents,)
|
||||
if self.output_hidden_states:
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
@@ -535,7 +560,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
||||
if past:
|
||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
return {"input_ids": input_ids, "past": past}
|
||||
return {"input_ids": input_ids, "past": past, "use_cache": kwargs["use_cache"]}
|
||||
|
||||
@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
@@ -548,6 +573,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
use_cache=True,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
@@ -565,8 +591,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
|
||||
Contains pre-computed hidden-states (key and values in the attention blocks).
|
||||
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
|
||||
should not be passed as input ids as they have already been computed.
|
||||
Can be used (see `past` input) to speed up sequential decoding.
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
@@ -600,6 +625,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
||||
@@ -652,6 +678,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
||||
mc_token_ids=None,
|
||||
lm_labels=None,
|
||||
mc_labels=None,
|
||||
use_cache=True,
|
||||
):
|
||||
r"""
|
||||
mc_token_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices)`, `optional`, default to index of the last token of the input)
|
||||
@@ -680,8 +707,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
||||
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
|
||||
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
|
||||
Contains pre-computed hidden-states (key and values in the attention blocks).
|
||||
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
|
||||
should not be passed as input ids as they have already been computed.
|
||||
Can be used (see `past` input) to speed up sequential decoding.
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
@@ -726,6 +752,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
||||
@@ -188,7 +188,6 @@ class T5Attention(nn.Module):
|
||||
super().__init__()
|
||||
self.is_decoder = config.is_decoder
|
||||
self.has_relative_attention_bias = has_relative_attention_bias
|
||||
self.output_past = config.output_past
|
||||
|
||||
self.output_attentions = config.output_attentions
|
||||
self.relative_attention_num_buckets = config.relative_attention_num_buckets
|
||||
@@ -300,6 +299,7 @@ class T5Attention(nn.Module):
|
||||
past_key_value_state=None,
|
||||
head_mask=None,
|
||||
query_length=None,
|
||||
use_cache=False,
|
||||
):
|
||||
"""
|
||||
Self-attention (if kv is None) or attention over source sentence (provided by kv).
|
||||
@@ -351,7 +351,7 @@ class T5Attention(nn.Module):
|
||||
else:
|
||||
k, v = past_key_value_state
|
||||
|
||||
if self.is_decoder and self.output_past:
|
||||
if self.is_decoder and use_cache:
|
||||
present_key_value_state = ((k, v),)
|
||||
else:
|
||||
present_key_value_state = (None,)
|
||||
@@ -385,14 +385,8 @@ class T5Attention(nn.Module):
|
||||
|
||||
context = self.o(context)
|
||||
|
||||
outputs = (context,)
|
||||
outputs = (context,) + present_key_value_state
|
||||
|
||||
if self.output_past is False or self.is_decoder is False:
|
||||
assert (
|
||||
present_key_value_state[0] is None
|
||||
), "Key/Value projections should not be stored if {} is not decoder or output_past is False".format(self)
|
||||
|
||||
outputs = outputs + present_key_value_state
|
||||
if self.output_attentions:
|
||||
outputs = outputs + (weights,)
|
||||
if self.has_relative_attention_bias:
|
||||
@@ -408,7 +402,13 @@ class T5LayerSelfAttention(nn.Module):
|
||||
self.dropout = nn.Dropout(config.dropout_rate)
|
||||
|
||||
def forward(
|
||||
self, hidden_states, attention_mask=None, position_bias=None, head_mask=None, past_key_value_state=None
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
position_bias=None,
|
||||
head_mask=None,
|
||||
past_key_value_state=None,
|
||||
use_cache=False,
|
||||
):
|
||||
norm_x = self.layer_norm(hidden_states)
|
||||
attention_output = self.SelfAttention(
|
||||
@@ -417,6 +417,7 @@ class T5LayerSelfAttention(nn.Module):
|
||||
position_bias=position_bias,
|
||||
head_mask=head_mask,
|
||||
past_key_value_state=past_key_value_state,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
y = attention_output[0]
|
||||
layer_output = hidden_states + self.dropout(y)
|
||||
@@ -439,6 +440,7 @@ class T5LayerCrossAttention(nn.Module):
|
||||
position_bias=None,
|
||||
head_mask=None,
|
||||
past_key_value_state=None,
|
||||
use_cache=False,
|
||||
query_length=None,
|
||||
):
|
||||
norm_x = self.layer_norm(hidden_states)
|
||||
@@ -449,6 +451,7 @@ class T5LayerCrossAttention(nn.Module):
|
||||
position_bias=position_bias,
|
||||
head_mask=head_mask,
|
||||
past_key_value_state=past_key_value_state,
|
||||
use_cache=use_cache,
|
||||
query_length=query_length,
|
||||
)
|
||||
y = attention_output[0]
|
||||
@@ -460,7 +463,6 @@ class T5LayerCrossAttention(nn.Module):
|
||||
class T5Block(nn.Module):
|
||||
def __init__(self, config, has_relative_attention_bias=False):
|
||||
super().__init__()
|
||||
self.output_past = config.output_past
|
||||
self.is_decoder = config.is_decoder
|
||||
self.layer = nn.ModuleList()
|
||||
self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
|
||||
@@ -479,6 +481,7 @@ class T5Block(nn.Module):
|
||||
encoder_decoder_position_bias=None,
|
||||
head_mask=None,
|
||||
past_key_value_state=None,
|
||||
use_cache=False,
|
||||
):
|
||||
|
||||
if past_key_value_state is not None:
|
||||
@@ -499,6 +502,7 @@ class T5Block(nn.Module):
|
||||
position_bias=position_bias,
|
||||
head_mask=head_mask,
|
||||
past_key_value_state=self_attn_past_key_value_state,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
hidden_states, present_key_value_state = self_attention_outputs[:2]
|
||||
attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
|
||||
@@ -519,6 +523,7 @@ class T5Block(nn.Module):
|
||||
head_mask=head_mask,
|
||||
past_key_value_state=cross_attn_past_key_value_state,
|
||||
query_length=query_length,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
hidden_states = cross_attention_outputs[0]
|
||||
# Combine self attn and cross attn key value states
|
||||
@@ -620,7 +625,6 @@ class T5Stack(T5PreTrainedModel):
|
||||
|
||||
self.embed_tokens = embed_tokens
|
||||
self.is_decoder = config.is_decoder
|
||||
self.output_past = config.output_past and self.is_decoder
|
||||
|
||||
self.block = nn.ModuleList(
|
||||
[T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
|
||||
@@ -648,6 +652,7 @@ class T5Stack(T5PreTrainedModel):
|
||||
inputs_embeds=None,
|
||||
head_mask=None,
|
||||
past_key_value_states=None,
|
||||
use_cache=False,
|
||||
):
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
@@ -699,7 +704,7 @@ class T5Stack(T5PreTrainedModel):
|
||||
causal_mask = seq_ids[None, None, :].repeat(batch_size, mask_seq_length, 1) <= seq_ids[None, :, None]
|
||||
causal_mask = causal_mask.to(attention_mask)
|
||||
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
||||
if self.output_past and past_key_value_states[0] is not None:
|
||||
if past_key_value_states[0] is not None:
|
||||
extended_attention_mask = extended_attention_mask[:, :, -1:, :]
|
||||
else:
|
||||
extended_attention_mask = attention_mask[:, None, None, :]
|
||||
@@ -776,6 +781,7 @@ class T5Stack(T5PreTrainedModel):
|
||||
encoder_decoder_position_bias=encoder_decoder_position_bias,
|
||||
head_mask=head_mask[i],
|
||||
past_key_value_state=past_key_value_state,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
# layer_outputs is a tuple with:
|
||||
# hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
|
||||
@@ -800,7 +806,8 @@ class T5Stack(T5PreTrainedModel):
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if self.is_decoder and self.output_past:
|
||||
if use_cache is True:
|
||||
assert self.is_decoder, "`use_cache` can only be set to `True` if {} is used as a decoder".format(self)
|
||||
outputs = outputs + (present_key_value_states,)
|
||||
if self.output_hidden_states:
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
@@ -833,7 +840,7 @@ T5_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
T5 is a model with relative position embeddings so you should be able to pad the inputs on both the right and the left. If `decoder_past_key_value_states` is used, optionally only the last `input_ids` have to be input (see `decoder_past_key_value_states`).
|
||||
T5 is a model with relative position embeddings so you should be able to pad the inputs on both the right and the left.
|
||||
Indices can be obtained using :class:`transformers.T5Tokenizer`.
|
||||
See :func:`transformers.PreTrainedTokenizer.encode` and
|
||||
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||
@@ -849,19 +856,26 @@ T5_INPUTS_DOCSTRING = r"""
|
||||
Used in the cross-attention of the decoder.
|
||||
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation.
|
||||
If `decoder_past_key_value_states` is used, optionally only the last `decoder_input_ids` have to be input (see `decoder_past_key_value_states`).
|
||||
To know more on how to prepare :obj:`decoder_input_ids` for pre-training take a look at
|
||||
`T5 Training <./t5.html#training>`_ .
|
||||
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
|
||||
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
|
||||
decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains pre-computed key and value hidden-states of the attention blocks.
|
||||
Can be used to speed up decoding. If `decoder_past_key_value_states` are used, the user can optionally input only the last `decoder_input_ids` of shape :obj:`(batch_size, 1)` instead of all `decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||
Can be used to speed up decoding.
|
||||
If `decoder_past_key_value_states` are used, the user can optionally input only the last `decoder_input_ids`
|
||||
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||
instead of all `decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
If `use_cache` is True, `decoder_past_key_value_states` are returned and can be used to speed up decoding (see `decoder_past_key_value_states`).
|
||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
||||
Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation.
|
||||
If `decoder_past_key_value_states` is used, optionally only the last `decoder_inputs_embeds` have to be input (see `decoder_past_key_value_states`).
|
||||
This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
head_mask: (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
|
||||
@@ -897,14 +911,6 @@ class T5Model(T5PreTrainedModel):
|
||||
self.encoder.set_input_embeddings(new_embeddings)
|
||||
self.decoder.set_input_embeddings(new_embeddings)
|
||||
|
||||
def set_output_past(self, do_output_past: bool):
|
||||
self.config.output_past = do_output_past
|
||||
self.decoder.output_past = do_output_past
|
||||
for block in self.decoder.block:
|
||||
block.output_past = do_output_past
|
||||
block.layer[0].SelfAttention.output_past = do_output_past
|
||||
block.layer[1].EncDecAttention.output_past = do_output_past
|
||||
|
||||
def get_encoder(self):
|
||||
return self.encoder
|
||||
|
||||
@@ -928,6 +934,7 @@ class T5Model(T5PreTrainedModel):
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
decoder_past_key_value_states=None,
|
||||
use_cache=True,
|
||||
inputs_embeds=None,
|
||||
decoder_inputs_embeds=None,
|
||||
head_mask=None,
|
||||
@@ -938,7 +945,7 @@ class T5Model(T5PreTrainedModel):
|
||||
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
If `decoder_past_key_value_states` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
|
||||
decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`, `optional`, returned when ``config.output_past=True``):
|
||||
decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`, `optional`, returned when ``use_cache=True``):
|
||||
Contains pre-computed key and value hidden-states of the attention blocks.
|
||||
Can be used to speed up sequential decoding (see `decoder_past_key_value_states` input).
|
||||
Note that when using `decoder_past_key_value_states`, the model only outputs the last `hidden-state` of the sequence of shape :obj:`(batch_size, 1, config.vocab_size)`.
|
||||
@@ -976,7 +983,7 @@ class T5Model(T5PreTrainedModel):
|
||||
|
||||
# If decoding with past key value states, only the last tokens
|
||||
# should be given as an input
|
||||
if decoder_past_key_value_states is not None and self.decoder.output_past is True:
|
||||
if decoder_past_key_value_states is not None:
|
||||
if decoder_input_ids is not None:
|
||||
decoder_input_ids = decoder_input_ids[:, -1:]
|
||||
if decoder_inputs_embeds is not None:
|
||||
@@ -991,9 +998,10 @@ class T5Model(T5PreTrainedModel):
|
||||
encoder_hidden_states=hidden_states,
|
||||
encoder_attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
if self.decoder.output_past:
|
||||
if use_cache is True:
|
||||
past = ((encoder_outputs, decoder_outputs[1]),)
|
||||
decoder_outputs = decoder_outputs[:1] + past + decoder_outputs[2:]
|
||||
|
||||
@@ -1022,14 +1030,6 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
def get_input_embeddings(self):
|
||||
return self.shared
|
||||
|
||||
def set_output_past(self, do_output_past: bool):
|
||||
self.config.output_past = do_output_past
|
||||
self.decoder.output_past = do_output_past
|
||||
for block in self.decoder.block:
|
||||
block.output_past = do_output_past
|
||||
block.layer[0].SelfAttention.output_past = do_output_past
|
||||
block.layer[1].EncDecAttention.output_past = do_output_past
|
||||
|
||||
def set_input_embeddings(self, new_embeddings):
|
||||
self.shared = new_embeddings
|
||||
self.encoder.set_input_embeddings(new_embeddings)
|
||||
@@ -1053,6 +1053,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
decoder_past_key_value_states=None,
|
||||
use_cache=True,
|
||||
lm_labels=None,
|
||||
inputs_embeds=None,
|
||||
decoder_inputs_embeds=None,
|
||||
@@ -1072,7 +1073,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
If `past_key_value_states` is used only the last prediction_scores of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
|
||||
decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`, `optional`, returned when ``config.output_past=True``):
|
||||
decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`, `optional`, returned when ``use_cache=True``):
|
||||
Contains pre-computed key and value hidden-states of the attention blocks.
|
||||
Can be used to speed up sequential decoding (see `decoder_past_key_value_states` input).
|
||||
Note that when using `decoder_past_key_value_states`, the model only outputs the last `prediction_score` of the sequence of shape :obj:`(batch_size, 1, config.vocab_size)`.
|
||||
@@ -1116,10 +1117,8 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
|
||||
# If decoding with past key value states, only the last tokens
|
||||
# should be given as an input
|
||||
if decoder_past_key_value_states is not None and self.decoder.output_past is True:
|
||||
assert (
|
||||
lm_labels is None
|
||||
), "Decoder should not use cached key value states when training. Also consider setting model.set_output_past(False) for less memory consumption"
|
||||
if decoder_past_key_value_states is not None:
|
||||
assert lm_labels is None, "Decoder should not use cached key value states when training."
|
||||
if decoder_input_ids is not None:
|
||||
decoder_input_ids = decoder_input_ids[:, -1:]
|
||||
if decoder_inputs_embeds is not None:
|
||||
@@ -1134,11 +1133,12 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
encoder_hidden_states=hidden_states,
|
||||
encoder_attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
# insert decoder past at right place
|
||||
# to speed up decoding
|
||||
if self.decoder.output_past:
|
||||
if use_cache is True:
|
||||
past = ((encoder_outputs, decoder_outputs[1]),)
|
||||
decoder_outputs = decoder_outputs[:1] + past + decoder_outputs[2:]
|
||||
|
||||
@@ -1157,7 +1157,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
|
||||
return decoder_outputs + encoder_outputs
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwargs):
|
||||
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, use_cache, **kwargs):
|
||||
assert past is not None, "past has to be defined for encoder_outputs"
|
||||
|
||||
# first step
|
||||
@@ -1171,13 +1171,14 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
"decoder_past_key_value_states": decoder_past_key_value_states,
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"attention_mask": attention_mask,
|
||||
"use_cache": use_cache,
|
||||
}
|
||||
|
||||
def _reorder_cache(self, past, beam_idx):
|
||||
# if decoder past is not included in output
|
||||
# speedy decoding is disabled and no need to reorder
|
||||
if len(past) < 2:
|
||||
logger.warning("You might want to consider setting model.set_output_past(True) to speed up decoding")
|
||||
logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
|
||||
return past
|
||||
|
||||
decoder_past = past[1]
|
||||
|
||||
@@ -94,7 +94,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
|
||||
return tf.transpose(x, perm=[0, 2, 1, 3])
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
v, k, q, mask, layer_past, attention_mask, head_mask = inputs
|
||||
v, k, q, mask, layer_past, attention_mask, head_mask, use_cache = inputs
|
||||
batch_size = shape_list(q)[0]
|
||||
|
||||
q = self.Wq(q)
|
||||
@@ -104,11 +104,25 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
|
||||
q = self.split_into_heads(q, batch_size)
|
||||
k = self.split_into_heads(k, batch_size)
|
||||
v = self.split_into_heads(v, batch_size)
|
||||
|
||||
if layer_past is not None:
|
||||
past_key, past_value = tf.unstack(layer_past, axis=0)
|
||||
k = tf.concat((past_key, k), axis=-2)
|
||||
v = tf.concat((past_value, v), axis=-2)
|
||||
present = tf.stack((k, v), axis=0)
|
||||
|
||||
# to cope with keras serialization
|
||||
# we need to cast `use_cache` to correct bool
|
||||
# if it is a tensor
|
||||
if tf.is_tensor(use_cache):
|
||||
if hasattr(use_cache, "numpy"):
|
||||
use_cache = bool(use_cache.numpy())
|
||||
else:
|
||||
use_cache = True
|
||||
|
||||
if use_cache is True:
|
||||
present = tf.stack((k, v), axis=0)
|
||||
else:
|
||||
present = (None,)
|
||||
|
||||
output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
|
||||
scaled_attention = tf.transpose(output[0], perm=[0, 2, 1, 3])
|
||||
@@ -147,10 +161,10 @@ class TFEncoderLayer(tf.keras.layers.Layer):
|
||||
self.dropout2 = tf.keras.layers.Dropout(rate)
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
x, mask, layer_past, attention_mask, head_mask = inputs
|
||||
x, mask, layer_past, attention_mask, head_mask, use_cache = inputs
|
||||
normed = self.layernorm1(x)
|
||||
attn_outputs = self.multi_head_attention(
|
||||
[normed, normed, normed, mask, layer_past, attention_mask, head_mask], training=training
|
||||
[normed, normed, normed, mask, layer_past, attention_mask, head_mask, use_cache], training=training
|
||||
)
|
||||
attn_output = attn_outputs[0]
|
||||
attn_output = self.dropout1(attn_output, training=training)
|
||||
@@ -173,7 +187,6 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
super().__init__(**kwargs)
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.output_attentions = config.output_attentions
|
||||
self.output_past = config.output_past
|
||||
|
||||
self.d_model_size = config.n_embd
|
||||
self.num_layers = config.n_layer
|
||||
@@ -220,8 +233,10 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=True,
|
||||
training=False,
|
||||
):
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
past = inputs[1] if len(inputs) > 1 else past
|
||||
@@ -230,7 +245,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
position_ids = inputs[4] if len(inputs) > 4 else position_ids
|
||||
head_mask = inputs[5] if len(inputs) > 5 else head_mask
|
||||
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
use_cache = inputs[7] if len(inputs) > 7 else use_cache
|
||||
assert len(inputs) <= 8, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
past = inputs.get("past", past)
|
||||
@@ -239,10 +255,21 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
position_ids = inputs.get("position_ids", position_ids)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
use_cache = inputs.get("use_cache", use_cache)
|
||||
assert len(inputs) <= 8, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
# If using past key value states, only the last tokens
|
||||
# should be given as an input
|
||||
if past is not None:
|
||||
if input_ids is not None:
|
||||
input_ids = input_ids[:, -1:]
|
||||
if inputs_embeds is not None:
|
||||
inputs_embeds = inputs_embeds[:, -1:]
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids[:, -1:]
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
@@ -319,10 +346,10 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
for i, (h, layer_past) in enumerate(zip(self.h, past)):
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
|
||||
outputs = h([hidden_states, mask, layer_past, attention_mask, head_mask[i]], training=training)
|
||||
outputs = h([hidden_states, mask, layer_past, attention_mask, head_mask[i], use_cache], training=training)
|
||||
hidden_states, present = outputs[:2]
|
||||
|
||||
if self.output_past:
|
||||
if use_cache is True:
|
||||
presents = presents + (present,)
|
||||
|
||||
if self.output_attentions:
|
||||
@@ -334,7 +361,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if self.output_past:
|
||||
if use_cache is True:
|
||||
outputs = outputs + (presents,)
|
||||
if self.output_hidden_states:
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
@@ -386,6 +413,7 @@ CTRL_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
If `past` is used, optionally only the last `input_ids` have to be input (see `past`).
|
||||
|
||||
Indices can be obtained using :class:`transformers.CTRLTokenizer`.
|
||||
See :func:`transformers.PreTrainedTokenizer.encode` and
|
||||
@@ -394,8 +422,10 @@ CTRL_INPUTS_DOCSTRING = r"""
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
past (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers`):
|
||||
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
(see `past` output below). Can be used to speed up sequential decoding. The token ids which have their past given to this model
|
||||
should not be passed as input ids as they have already been computed.
|
||||
(see `past` output below). Can be used to speed up sequential decoding.
|
||||
If `past` is used, the user can optionally input only the last `input_ids`
|
||||
(those that don't have their past given to this model) of shape :obj:`(batch_size, 1)`
|
||||
instead of all `input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||
attention_mask (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
@@ -406,6 +436,7 @@ CTRL_INPUTS_DOCSTRING = r"""
|
||||
Segment token indices to indicate first and second portions of the inputs.
|
||||
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
|
||||
corresponds to a `sentence B` token
|
||||
If `past` is used, optionally only the last `token_type_ids` have to be input (see `past`).
|
||||
|
||||
`What are token type IDs? <../glossary.html#token-type-ids>`_
|
||||
position_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
@@ -421,6 +452,10 @@ CTRL_INPUTS_DOCSTRING = r"""
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
If `past` is used, optionally only the last `input_embeds` have to be input (see `past`).
|
||||
use_cache (:obj:`bool`):
|
||||
If `use_cache` is True, `past` key value states are returned and
|
||||
can be used to speed up decoding (see `past`). Defaults to `True`.
|
||||
training (:obj:`boolean`, `optional`, defaults to :obj:`False`):
|
||||
Whether to activate dropout modules (if set to :obj:`True`) during training or to de-activate them
|
||||
(if set to :obj:`False`) for evaluation.
|
||||
@@ -514,7 +549,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel):
|
||||
if past:
|
||||
inputs = tf.expand_dims(inputs[:, -1], -1)
|
||||
|
||||
return {"inputs": inputs, "past": past}
|
||||
return {"inputs": inputs, "past": past, "use_cache": kwargs["use_cache"]}
|
||||
|
||||
@add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
|
||||
def call(self, inputs, **kwargs):
|
||||
|
||||
@@ -134,7 +134,7 @@ class TFAttention(tf.keras.layers.Layer):
|
||||
return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features)
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
x, layer_past, attention_mask, head_mask = inputs
|
||||
x, layer_past, attention_mask, head_mask, use_cache = inputs
|
||||
|
||||
x = self.c_attn(x)
|
||||
query, key, value = tf.split(x, 3, axis=2)
|
||||
@@ -145,7 +145,20 @@ class TFAttention(tf.keras.layers.Layer):
|
||||
past_key, past_value = tf.unstack(layer_past, axis=0)
|
||||
key = tf.concat([past_key, key], axis=-2)
|
||||
value = tf.concat([past_value, value], axis=-2)
|
||||
present = tf.stack([key, value], axis=0)
|
||||
|
||||
# to cope with keras serialization
|
||||
# we need to cast `use_cache` to correct bool
|
||||
# if it is a tensor
|
||||
if tf.is_tensor(use_cache):
|
||||
if hasattr(use_cache, "numpy"):
|
||||
use_cache = bool(use_cache.numpy())
|
||||
else:
|
||||
use_cache = True
|
||||
|
||||
if use_cache is True:
|
||||
present = tf.stack([key, value], axis=0)
|
||||
else:
|
||||
present = (None,)
|
||||
|
||||
attn_outputs = self._attn([query, key, value, attention_mask, head_mask], training=training)
|
||||
a = attn_outputs[0]
|
||||
@@ -184,10 +197,10 @@ class TFBlock(tf.keras.layers.Layer):
|
||||
self.mlp = TFMLP(4 * nx, config, name="mlp")
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
x, layer_past, attention_mask, head_mask = inputs
|
||||
x, layer_past, attention_mask, head_mask, use_cache = inputs
|
||||
|
||||
a = self.ln_1(x)
|
||||
output_attn = self.attn([a, layer_past, attention_mask, head_mask], training=training)
|
||||
output_attn = self.attn([a, layer_past, attention_mask, head_mask, use_cache], training=training)
|
||||
a = output_attn[0] # output_attn: a, present, (attentions)
|
||||
x = x + a
|
||||
|
||||
@@ -245,6 +258,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=True,
|
||||
training=False,
|
||||
):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
@@ -255,7 +269,8 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
||||
position_ids = inputs[4] if len(inputs) > 4 else position_ids
|
||||
head_mask = inputs[5] if len(inputs) > 5 else head_mask
|
||||
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
use_cache = inputs[7] if len(inputs) > 7 else use_cache
|
||||
assert len(inputs) <= 8, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
past = inputs.get("past", past)
|
||||
@@ -264,10 +279,21 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
||||
position_ids = inputs.get("position_ids", position_ids)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
use_cache = inputs.get("use_cache", use_cache)
|
||||
assert len(inputs) <= 8, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
# If using past key value states, only the last tokens
|
||||
# should be given as an input
|
||||
if past is not None:
|
||||
if input_ids is not None:
|
||||
input_ids = input_ids[:, -1:]
|
||||
if inputs_embeds is not None:
|
||||
inputs_embeds = inputs_embeds[:, -1:]
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids[:, -1:]
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
@@ -338,7 +364,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
|
||||
|
||||
outputs = block([hidden_states, layer_past, attention_mask, head_mask[i]], training=training)
|
||||
outputs = block([hidden_states, layer_past, attention_mask, head_mask[i], use_cache], training=training)
|
||||
|
||||
hidden_states, present = outputs[:2]
|
||||
presents = presents + (present,)
|
||||
@@ -353,7 +379,10 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
outputs = (hidden_states, presents)
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if use_cache is True:
|
||||
outputs = outputs + (presents,)
|
||||
if self.output_hidden_states:
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
if self.output_attentions:
|
||||
@@ -404,6 +433,7 @@ GPT2_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
If `past` is used, optionally only the last `input_ids` have to be input (see `past`).
|
||||
|
||||
Indices can be obtained using :class:`transformers.GPT2Tokenizer`.
|
||||
See :func:`transformers.PreTrainedTokenizer.encode` and
|
||||
@@ -424,6 +454,7 @@ GPT2_INPUTS_DOCSTRING = r"""
|
||||
Segment token indices to indicate first and second portions of the inputs.
|
||||
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
|
||||
corresponds to a `sentence B` token
|
||||
If `past` is used, optionally only the last `token_type_ids` have to be input (see `past`).
|
||||
|
||||
`What are token type IDs? <../glossary.html#token-type-ids>`_
|
||||
position_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
@@ -439,6 +470,7 @@ GPT2_INPUTS_DOCSTRING = r"""
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
If `past` is used, optionally only the last `input_embeds` have to be input (see `past`).
|
||||
training (:obj:`boolean`, `optional`, defaults to :obj:`False`):
|
||||
Whether to activate dropout modules (if set to :obj:`True`) during training or to de-activate them
|
||||
(if set to :obj:`False`) for evaluation.
|
||||
@@ -511,7 +543,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel):
|
||||
if past:
|
||||
inputs = tf.expand_dims(inputs[:, -1], -1)
|
||||
|
||||
return {"inputs": inputs, "past": past}
|
||||
return {"inputs": inputs, "past": past, "use_cache": kwargs["use_cache"]}
|
||||
|
||||
@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
|
||||
def call(self, inputs, **kwargs):
|
||||
@@ -590,6 +622,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
mc_token_ids=None,
|
||||
use_cache=True,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
@@ -656,7 +689,8 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
|
||||
head_mask = inputs[5] if len(inputs) > 5 else head_mask
|
||||
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
|
||||
mc_token_ids = inputs[7] if len(inputs) > 7 else mc_token_ids
|
||||
assert len(inputs) <= 8, "Too many inputs."
|
||||
use_cache = inputs[8] if len(inputs) > 8 else use_cache
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
elif isinstance(inputs, dict):
|
||||
input_ids = inputs.get("input_ids")
|
||||
past = inputs.get("past", past)
|
||||
@@ -666,7 +700,8 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
mc_token_ids = inputs.get("mc_token_ids", mc_token_ids)
|
||||
assert len(inputs) <= 8, "Too many inputs."
|
||||
use_cache = inputs.get("use_cache", use_cache)
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
@@ -690,6 +725,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
|
||||
flat_position_ids,
|
||||
head_mask,
|
||||
inputs_embeds,
|
||||
use_cache,
|
||||
]
|
||||
|
||||
transformer_outputs = self.transformer(flat_inputs, training=training)
|
||||
|
||||
@@ -444,16 +444,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
def prepare_inputs_for_generation(self, inputs, **kwargs):
|
||||
return {"inputs": inputs}
|
||||
|
||||
def _do_output_past(self, outputs):
|
||||
has_output_past = hasattr(self.config, "output_past") and self.config.output_past
|
||||
has_mem_len = hasattr(self.config, "mem_len") and self.config.mem_len
|
||||
|
||||
if has_output_past and not has_mem_len and len(outputs) > 1:
|
||||
return True
|
||||
elif has_mem_len and self.config.mem_len > 0 and len(outputs) > 1:
|
||||
return True
|
||||
|
||||
return False
|
||||
def _use_cache(self, outputs, use_cache):
|
||||
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
|
||||
if len(outputs) <= 1 or use_cache is False:
|
||||
return False
|
||||
if hasattr(self.config, "mem_len") and self.config.mem_len == 0:
|
||||
return False
|
||||
return True
|
||||
|
||||
def generate(
|
||||
self,
|
||||
@@ -476,6 +473,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
num_return_sequences=None,
|
||||
attention_mask=None,
|
||||
decoder_start_token_id=None,
|
||||
use_cache=None,
|
||||
):
|
||||
r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
|
||||
and beam-search.
|
||||
@@ -551,6 +549,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
If an encoder-decoder model starts decoding with a different token than BOS.
|
||||
Defaults to `None` and is changed to `BOS` later.
|
||||
|
||||
use_cache: (`optional`) bool
|
||||
If `use_cache` is True, past key values are used to speed up decoding if applicable to model. Defaults to `True`.
|
||||
|
||||
Return:
|
||||
|
||||
output: `tf.Tensor` of `dtype=tf.int32` shape `(batch_size * num_return_sequences, sequence_length)`
|
||||
@@ -605,6 +606,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
min_length = min_length if min_length is not None else self.config.min_length
|
||||
do_sample = do_sample if do_sample is not None else self.config.do_sample
|
||||
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
num_beams = num_beams if num_beams is not None else self.config.num_beams
|
||||
temperature = temperature if temperature is not None else self.config.temperature
|
||||
top_k = top_k if top_k is not None else self.config.top_k
|
||||
@@ -634,6 +636,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
|
||||
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
|
||||
assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
|
||||
assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
|
||||
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictely positive integer."
|
||||
assert temperature > 0, "`temperature` should be strictely positive."
|
||||
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
|
||||
@@ -782,6 +785,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
vocab_size=vocab_size,
|
||||
encoder_outputs=encoder_outputs,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
else:
|
||||
output = self._generate_no_beam_search(
|
||||
@@ -804,6 +808,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
vocab_size=vocab_size,
|
||||
encoder_outputs=encoder_outputs,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
return output
|
||||
@@ -829,6 +834,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
vocab_size,
|
||||
encoder_outputs,
|
||||
attention_mask,
|
||||
use_cache,
|
||||
):
|
||||
""" Generate sequences for each example without beam search (num_beams == 1).
|
||||
All returned sequence are generated independantly.
|
||||
@@ -841,12 +847,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
past = encoder_outputs # defined for encoder-decoder models, None for decoder-only models
|
||||
|
||||
while cur_len < max_length:
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
|
||||
model_inputs = self.prepare_inputs_for_generation(
|
||||
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache
|
||||
)
|
||||
outputs = self(**model_inputs)
|
||||
next_token_logits = outputs[0][:, -1, :]
|
||||
|
||||
# if model has past, then set the past variable to speed up decoding
|
||||
if self._do_output_past(outputs):
|
||||
if self._use_cache(outputs, use_cache):
|
||||
past = outputs[1]
|
||||
|
||||
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
|
||||
@@ -993,6 +1001,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
vocab_size,
|
||||
encoder_outputs,
|
||||
attention_mask,
|
||||
use_cache,
|
||||
):
|
||||
""" Generate sequences for each example with beam search.
|
||||
"""
|
||||
@@ -1020,12 +1029,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
done = [False for _ in range(batch_size)]
|
||||
|
||||
while cur_len < max_length:
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
|
||||
model_inputs = self.prepare_inputs_for_generation(
|
||||
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache
|
||||
)
|
||||
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
|
||||
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
|
||||
|
||||
# if model has past, then set the past variable to speed up decoding
|
||||
if self._do_output_past(outputs):
|
||||
if self._use_cache(outputs, use_cache):
|
||||
past = outputs[1]
|
||||
|
||||
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
|
||||
|
||||
@@ -358,7 +358,6 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
super().__init__(**kwargs)
|
||||
self.output_attentions = config.output_attentions
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.output_past = config.output_past
|
||||
|
||||
self.mem_len = config.mem_len
|
||||
self.reuse_len = config.reuse_len
|
||||
@@ -503,6 +502,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
input_mask=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=True,
|
||||
training=False,
|
||||
):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
@@ -515,7 +515,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
input_mask = inputs[6] if len(inputs) > 6 else input_mask
|
||||
head_mask = inputs[7] if len(inputs) > 7 else head_mask
|
||||
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
use_cache = inputs[9] if len(inputs) > 9 else use_cache
|
||||
assert len(inputs) <= 10, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
@@ -526,7 +527,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
input_mask = inputs.get("input_mask", input_mask)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
use_cache = inputs.get("use_cache", use_cache)
|
||||
assert len(inputs) <= 10, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
@@ -657,7 +659,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
hidden_states = []
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
# cache new mems
|
||||
if self.mem_len is not None and self.mem_len > 0 and self.output_past:
|
||||
if self.mem_len is not None and self.mem_len > 0 and use_cache is True:
|
||||
new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
|
||||
if self.output_hidden_states:
|
||||
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
|
||||
@@ -679,7 +681,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
# Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
|
||||
outputs = (tf.transpose(output, perm=(1, 0, 2)),)
|
||||
|
||||
if self.mem_len is not None and self.mem_len > 0 and self.output_past:
|
||||
if self.mem_len is not None and self.mem_len > 0 and use_cache is True:
|
||||
outputs = outputs + (new_mems,)
|
||||
|
||||
if self.output_hidden_states:
|
||||
@@ -783,6 +785,8 @@ XLNET_INPUTS_DOCSTRING = r"""
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
use_cache (:obj:`bool`):
|
||||
If `use_cache` is True, `mems` are returned and can be used to speed up decoding (see `mems`). Defaults to `True`.
|
||||
"""
|
||||
|
||||
|
||||
@@ -848,7 +852,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_loss.input_embeddings
|
||||
|
||||
def prepare_inputs_for_generation(self, inputs, past, **model_kwargs):
|
||||
def prepare_inputs_for_generation(self, inputs, past, **kwargs):
|
||||
# Add dummy token at the end (no attention on this one)
|
||||
|
||||
effective_batch_size = inputs.shape[0]
|
||||
@@ -866,7 +870,12 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
|
||||
target_mapping_seq_end = tf.ones((effective_batch_size, 1, 1), dtype=tf.float32)
|
||||
target_mapping = tf.concat([target_mapping, target_mapping_seq_end], axis=-1)
|
||||
|
||||
inputs = {"inputs": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping}
|
||||
inputs = {
|
||||
"inputs": inputs,
|
||||
"perm_mask": perm_mask,
|
||||
"target_mapping": target_mapping,
|
||||
"use_cache": kwargs["use_cache"],
|
||||
}
|
||||
|
||||
# if past is defined in model kwargs then use it for faster decoding
|
||||
if past:
|
||||
|
||||
@@ -652,15 +652,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
def prepare_scores_for_generation(self, scores, **kwargs):
|
||||
return scores
|
||||
|
||||
def _do_output_past(self, outputs):
|
||||
def _use_cache(self, outputs, use_cache):
|
||||
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
|
||||
has_output_past = getattr(self.config, "output_past", False)
|
||||
mem_len = getattr(self.config, "mem_len", 0)
|
||||
if len(outputs) <= 1:
|
||||
if len(outputs) <= 1 or use_cache is False:
|
||||
return False
|
||||
if mem_len > 0 or has_output_past:
|
||||
return True
|
||||
return False
|
||||
if hasattr(self.config, "mem_len") and self.config.mem_len == 0:
|
||||
return False
|
||||
return True
|
||||
|
||||
def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
|
||||
"""repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """
|
||||
@@ -694,6 +692,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
num_return_sequences=None,
|
||||
attention_mask=None,
|
||||
decoder_start_token_id=None,
|
||||
use_cache=None,
|
||||
):
|
||||
r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.
|
||||
|
||||
@@ -768,6 +767,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
If an encoder-decoder model starts decoding with a different token than BOS.
|
||||
Defaults to `None` and is changed to `BOS` later.
|
||||
|
||||
use_cache: (`optional`) bool
|
||||
If `use_cache` is True, past key values are used to speed up decoding if applicable to model. Defaults to `True`.
|
||||
|
||||
Return:
|
||||
|
||||
output: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`
|
||||
@@ -822,6 +824,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
min_length = min_length if min_length is not None else self.config.min_length
|
||||
do_sample = do_sample if do_sample is not None else self.config.do_sample
|
||||
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
num_beams = num_beams if num_beams is not None else self.config.num_beams
|
||||
temperature = temperature if temperature is not None else self.config.temperature
|
||||
top_k = top_k if top_k is not None else self.config.top_k
|
||||
@@ -851,6 +854,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
|
||||
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
|
||||
assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
|
||||
assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
|
||||
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
|
||||
assert temperature > 0, "`temperature` should be strictly positive."
|
||||
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
|
||||
@@ -1011,6 +1015,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
vocab_size=vocab_size,
|
||||
encoder_outputs=encoder_outputs,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
else:
|
||||
output = self._generate_no_beam_search(
|
||||
@@ -1032,6 +1037,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
batch_size=effective_batch_size,
|
||||
encoder_outputs=encoder_outputs,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
return output
|
||||
@@ -1056,6 +1062,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
batch_size,
|
||||
encoder_outputs,
|
||||
attention_mask,
|
||||
use_cache,
|
||||
):
|
||||
""" Generate sequences for each example without beam search (num_beams == 1).
|
||||
All returned sequence are generated independantly.
|
||||
@@ -1067,13 +1074,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
past = encoder_outputs # defined for encoder-decoder models, None for decoder-only models
|
||||
|
||||
while cur_len < max_length:
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
|
||||
model_inputs = self.prepare_inputs_for_generation(
|
||||
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache
|
||||
)
|
||||
|
||||
outputs = self(**model_inputs)
|
||||
next_token_logits = outputs[0][:, -1, :]
|
||||
|
||||
# if model has past, then set the past variable to speed up decoding
|
||||
if self._do_output_past(outputs):
|
||||
if self._use_cache(outputs, use_cache):
|
||||
past = outputs[1]
|
||||
|
||||
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
|
||||
@@ -1178,6 +1187,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
vocab_size,
|
||||
encoder_outputs,
|
||||
attention_mask,
|
||||
use_cache,
|
||||
):
|
||||
""" Generate sequences for each example with beam search.
|
||||
"""
|
||||
@@ -1203,12 +1213,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
done = [False for _ in range(batch_size)]
|
||||
|
||||
while cur_len < max_length:
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
|
||||
model_inputs = self.prepare_inputs_for_generation(
|
||||
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache
|
||||
)
|
||||
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
|
||||
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
|
||||
|
||||
# if model has past, then set the past variable to speed up decoding
|
||||
if self._do_output_past(outputs):
|
||||
if self._use_cache(outputs, use_cache):
|
||||
past = outputs[1]
|
||||
|
||||
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
|
||||
|
||||
@@ -524,6 +524,7 @@ XLNET_INPUTS_DOCSTRING = r"""
|
||||
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
(see `mems` output below). Can be used to speed up sequential decoding. The token ids which have their mems
|
||||
given to this model should not be passed as input ids as they have already been computed.
|
||||
`use_cache` has to be set to `True` to make use of `mems`.
|
||||
perm_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Mask to indicate the attention pattern for each input token with values selected in ``[0, 1]``:
|
||||
If ``perm_mask[k, i, j] = 0``, i attend to j in batch k;
|
||||
@@ -555,6 +556,8 @@ XLNET_INPUTS_DOCSTRING = r"""
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
use_cache (:obj:`bool`):
|
||||
If `use_cache` is True, `mems` are returned and can be used to speed up decoding (see `mems`). Defaults to `True`.
|
||||
"""
|
||||
|
||||
|
||||
@@ -567,7 +570,6 @@ class XLNetModel(XLNetPreTrainedModel):
|
||||
super().__init__(config)
|
||||
self.output_attentions = config.output_attentions
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.output_past = config.output_past
|
||||
|
||||
self.mem_len = config.mem_len
|
||||
self.reuse_len = config.reuse_len
|
||||
@@ -698,6 +700,7 @@ class XLNetModel(XLNetPreTrainedModel):
|
||||
input_mask=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=True,
|
||||
):
|
||||
r"""
|
||||
Return:
|
||||
@@ -864,7 +867,7 @@ class XLNetModel(XLNetPreTrainedModel):
|
||||
attentions = []
|
||||
hidden_states = []
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if self.mem_len is not None and self.mem_len > 0 and self.output_past:
|
||||
if self.mem_len is not None and self.mem_len > 0 and use_cache is True:
|
||||
# cache new mems
|
||||
new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
|
||||
if self.output_hidden_states:
|
||||
@@ -894,7 +897,7 @@ class XLNetModel(XLNetPreTrainedModel):
|
||||
# Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
|
||||
outputs = (output.permute(1, 0, 2).contiguous(),)
|
||||
|
||||
if self.mem_len is not None and self.mem_len > 0 and self.output_past:
|
||||
if self.mem_len is not None and self.mem_len > 0 and use_cache is True:
|
||||
outputs = outputs + (new_mems,)
|
||||
|
||||
if self.output_hidden_states:
|
||||
@@ -935,7 +938,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_loss
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past, **model_kwargs):
|
||||
def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
|
||||
# Add dummy token at the end (no attention on this one)
|
||||
|
||||
effective_batch_size = input_ids.shape[0]
|
||||
@@ -955,7 +958,12 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
||||
)
|
||||
target_mapping[0, 0, -1] = 1.0
|
||||
|
||||
inputs = {"input_ids": input_ids, "perm_mask": perm_mask, "target_mapping": target_mapping}
|
||||
inputs = {
|
||||
"input_ids": input_ids,
|
||||
"perm_mask": perm_mask,
|
||||
"target_mapping": target_mapping,
|
||||
"use_cache": kwargs["use_cache"],
|
||||
}
|
||||
|
||||
# if past is defined in model kwargs then use it for faster decoding
|
||||
if past:
|
||||
@@ -975,6 +983,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
||||
input_mask=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=True,
|
||||
labels=None,
|
||||
):
|
||||
r"""
|
||||
@@ -1050,6 +1059,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
||||
input_mask=input_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
logits = self.lm_loss(transformer_outputs[0])
|
||||
@@ -1093,6 +1103,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
|
||||
input_mask=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=True,
|
||||
labels=None,
|
||||
):
|
||||
r"""
|
||||
@@ -1148,6 +1159,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
|
||||
input_mask=input_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
output = transformer_outputs[0]
|
||||
|
||||
@@ -1196,6 +1208,7 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
|
||||
input_mask=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=True,
|
||||
labels=None,
|
||||
):
|
||||
r"""
|
||||
@@ -1252,6 +1265,7 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
|
||||
input_mask=input_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@@ -1301,9 +1315,10 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
|
||||
mems=None,
|
||||
perm_mask=None,
|
||||
target_mapping=None,
|
||||
labels=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=True,
|
||||
labels=None,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
||||
@@ -1368,6 +1383,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
|
||||
target_mapping=target_mapping,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
output = transformer_outputs[0]
|
||||
@@ -1414,6 +1430,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
|
||||
input_mask=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=True,
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
):
|
||||
@@ -1478,6 +1495,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
|
||||
input_mask=input_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@@ -1538,6 +1556,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
|
||||
input_mask=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=True,
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
is_impossible=None,
|
||||
@@ -1616,6 +1635,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
|
||||
input_mask=input_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
hidden_states = transformer_outputs[0]
|
||||
start_logits = self.start_logits(hidden_states, p_mask=p_mask)
|
||||
|
||||
Reference in New Issue
Block a user