[Config, Caching] Remove output_past everywhere and replace by use_cache argument (#3734)
* remove output_past from pt * make style * add optional input length for gpt2 * add use cache to prepare input * save memory in gpt2 * correct gpt2 test inputs * make past input optional for gpt2 * finish use_cache for all models * make style * delete modeling_gpt2 change in test file * correct docstring * correct is true statements for gpt2
This commit is contained in:
committed by
GitHub
parent
092cf881a5
commit
01c37dcdb5
@@ -98,7 +98,7 @@ class MultiHeadAttention(torch.nn.Module):
|
||||
x = x.reshape(batch_size, -1, self.num_heads, self.depth)
|
||||
return x.permute([0, 2, 1, 3])
|
||||
|
||||
def forward(self, v, k, q, mask, layer_past=None, attention_mask=None, head_mask=None):
|
||||
def forward(self, v, k, q, mask, layer_past=None, attention_mask=None, head_mask=None, use_cache=False):
|
||||
batch_size = q.shape[0]
|
||||
|
||||
q = self.Wq(q)
|
||||
@@ -112,7 +112,11 @@ class MultiHeadAttention(torch.nn.Module):
|
||||
past_key, past_value = layer_past[0], layer_past[1]
|
||||
k = torch.cat((past_key, k), dim=-2)
|
||||
v = torch.cat((past_value, v), dim=-2)
|
||||
present = torch.stack((k, v))
|
||||
|
||||
if use_cache is True:
|
||||
present = torch.stack((k, v))
|
||||
else:
|
||||
present = (None,)
|
||||
|
||||
output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
|
||||
scaled_attention = output[0].permute([0, 2, 1, 3])
|
||||
@@ -143,10 +147,17 @@ class EncoderLayer(torch.nn.Module):
|
||||
self.dropout1 = torch.nn.Dropout(rate)
|
||||
self.dropout2 = torch.nn.Dropout(rate)
|
||||
|
||||
def forward(self, x, mask, layer_past=None, attention_mask=None, head_mask=None):
|
||||
def forward(self, x, mask, layer_past=None, attention_mask=None, head_mask=None, use_cache=False):
|
||||
normed = self.layernorm1(x)
|
||||
attn_outputs = self.multi_head_attention(
|
||||
normed, normed, normed, mask, layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask
|
||||
normed,
|
||||
normed,
|
||||
normed,
|
||||
mask,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
attn_output = attn_outputs[0]
|
||||
attn_output = self.dropout1(attn_output)
|
||||
@@ -199,6 +210,7 @@ CTRL_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
If `past` is used, optionally only the last `input_ids` have to be input (see `past`).
|
||||
|
||||
Indices can be obtained using :class:`transformers.CTRLTokenizer`.
|
||||
See :func:`transformers.PreTrainedTokenizer.encode` and
|
||||
@@ -207,8 +219,10 @@ CTRL_INPUTS_DOCSTRING = r"""
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
||||
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
(see `past` output below). Can be used to speed up sequential decoding. The token ids which have their past given to this model
|
||||
should not be passed as input ids as they have already been computed.
|
||||
(see `past` output below). Can be used to speed up sequential decoding.
|
||||
If `past` is used, the user can optionally input only the last `input_ids`
|
||||
(those that don't have their past given to this model) of shape :obj:`(batch_size, 1)`
|
||||
instead of all `input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
@@ -219,6 +233,7 @@ CTRL_INPUTS_DOCSTRING = r"""
|
||||
Segment token indices to indicate first and second portions of the inputs.
|
||||
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
|
||||
corresponds to a `sentence B` token
|
||||
If `past` is used, optionally only the last `token_type_ids` have to be input (see `past`).
|
||||
|
||||
`What are token type IDs? <../glossary.html#token-type-ids>`_
|
||||
position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
@@ -234,6 +249,10 @@ CTRL_INPUTS_DOCSTRING = r"""
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
If `past` is used, optionally only the last `input_embeds` have to be input (see `past`).
|
||||
use_cache (:obj:`bool`):
|
||||
If `use_cache` is True, `past` key value states are returned and
|
||||
can be used to speed up decoding (see `past`). Defaults to `True`.
|
||||
"""
|
||||
|
||||
|
||||
@@ -246,7 +265,6 @@ class CTRLModel(CTRLPreTrainedModel):
|
||||
super().__init__(config)
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.output_attentions = config.output_attentions
|
||||
self.output_past = config.output_past
|
||||
|
||||
self.d_model_size = config.n_embd
|
||||
self.num_layers = config.n_layer
|
||||
@@ -289,6 +307,7 @@ class CTRLModel(CTRLPreTrainedModel):
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=True,
|
||||
):
|
||||
r"""
|
||||
Return:
|
||||
@@ -297,8 +316,7 @@ class CTRLModel(CTRLPreTrainedModel):
|
||||
Sequence of hidden-states at the last layer of the model.
|
||||
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
|
||||
Contains pre-computed hidden-states (key and values in the attention blocks).
|
||||
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
|
||||
should not be passed as input ids as they have already been computed.
|
||||
Can be used (see `past` input) to speed up sequential decoding.
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
@@ -325,6 +343,17 @@ class CTRLModel(CTRLPreTrainedModel):
|
||||
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
||||
|
||||
"""
|
||||
|
||||
# If using past key value states, only the last tokens
|
||||
# should be given as an input
|
||||
if past is not None:
|
||||
if input_ids is not None:
|
||||
input_ids = input_ids[:, -1:]
|
||||
if inputs_embeds is not None:
|
||||
inputs_embeds = inputs_embeds[:, -1:]
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids[:, -1:]
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
@@ -414,10 +443,15 @@ class CTRLModel(CTRLPreTrainedModel):
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
|
||||
outputs = h(
|
||||
hidden_states, mask, layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask[i]
|
||||
hidden_states,
|
||||
mask,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask[i],
|
||||
use_cache=use_cache,
|
||||
)
|
||||
hidden_states, present = outputs[:2]
|
||||
if self.output_past:
|
||||
if use_cache is True:
|
||||
presents = presents + (present,)
|
||||
|
||||
if self.output_attentions:
|
||||
@@ -429,7 +463,7 @@ class CTRLModel(CTRLPreTrainedModel):
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if self.output_past:
|
||||
if use_cache is True:
|
||||
outputs = outputs + (presents,)
|
||||
if self.output_hidden_states:
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
@@ -462,7 +496,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
||||
if past:
|
||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
return {"input_ids": input_ids, "past": past}
|
||||
return {"input_ids": input_ids, "past": past, "use_cache": kwargs["use_cache"]}
|
||||
|
||||
@add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
@@ -475,6 +509,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
use_cache=True,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
@@ -492,8 +527,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
|
||||
Contains pre-computed hidden-states (key and values in the attention blocks).
|
||||
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
|
||||
should not be passed as input ids as they have already been computed.
|
||||
Can be used (see `past` input) to speed up sequential decoding.
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
@@ -527,6 +561,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
||||
Reference in New Issue
Block a user