[Config, Caching] Remove output_past everywhere and replace by use_cache argument (#3734)

* remove output_past from pt

* make style

* add optional input length for gpt2

* add use cache to prepare input

* save memory in gpt2

* correct gpt2 test inputs

* make past input optional for gpt2

* finish use_cache for all models

* make style

* delete modeling_gpt2 change in test file

* correct docstring

* correct is true statements for gpt2
This commit is contained in:
Patrick von Platen
2020-04-14 20:40:28 +02:00
committed by GitHub
parent 092cf881a5
commit 01c37dcdb5
15 changed files with 342 additions and 168 deletions

View File

@@ -98,7 +98,7 @@ class MultiHeadAttention(torch.nn.Module):
x = x.reshape(batch_size, -1, self.num_heads, self.depth)
return x.permute([0, 2, 1, 3])
def forward(self, v, k, q, mask, layer_past=None, attention_mask=None, head_mask=None):
def forward(self, v, k, q, mask, layer_past=None, attention_mask=None, head_mask=None, use_cache=False):
batch_size = q.shape[0]
q = self.Wq(q)
@@ -112,7 +112,11 @@ class MultiHeadAttention(torch.nn.Module):
past_key, past_value = layer_past[0], layer_past[1]
k = torch.cat((past_key, k), dim=-2)
v = torch.cat((past_value, v), dim=-2)
present = torch.stack((k, v))
if use_cache is True:
present = torch.stack((k, v))
else:
present = (None,)
output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
scaled_attention = output[0].permute([0, 2, 1, 3])
@@ -143,10 +147,17 @@ class EncoderLayer(torch.nn.Module):
self.dropout1 = torch.nn.Dropout(rate)
self.dropout2 = torch.nn.Dropout(rate)
def forward(self, x, mask, layer_past=None, attention_mask=None, head_mask=None):
def forward(self, x, mask, layer_past=None, attention_mask=None, head_mask=None, use_cache=False):
normed = self.layernorm1(x)
attn_outputs = self.multi_head_attention(
normed, normed, normed, mask, layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask
normed,
normed,
normed,
mask,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask,
use_cache=use_cache,
)
attn_output = attn_outputs[0]
attn_output = self.dropout1(attn_output)
@@ -199,6 +210,7 @@ CTRL_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
If `past` is used, optionally only the last `input_ids` have to be input (see `past`).
Indices can be obtained using :class:`transformers.CTRLTokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and
@@ -207,8 +219,10 @@ CTRL_INPUTS_DOCSTRING = r"""
`What are input IDs? <../glossary.html#input-ids>`__
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `past` output below). Can be used to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
(see `past` output below). Can be used to speed up sequential decoding.
If `past` is used, the user can optionally input only the last `input_ids`
(those that don't have their past given to this model) of shape :obj:`(batch_size, 1)`
instead of all `input_ids` of shape :obj:`(batch_size, sequence_length)`.
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
@@ -219,6 +233,7 @@ CTRL_INPUTS_DOCSTRING = r"""
Segment token indices to indicate first and second portions of the inputs.
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
corresponds to a `sentence B` token
If `past` is used, optionally only the last `token_type_ids` have to be input (see `past`).
`What are token type IDs? <../glossary.html#token-type-ids>`_
position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@@ -234,6 +249,10 @@ CTRL_INPUTS_DOCSTRING = r"""
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
If `past` is used, optionally only the last `input_embeds` have to be input (see `past`).
use_cache (:obj:`bool`):
If `use_cache` is True, `past` key value states are returned and
can be used to speed up decoding (see `past`). Defaults to `True`.
"""
@@ -246,7 +265,6 @@ class CTRLModel(CTRLPreTrainedModel):
super().__init__(config)
self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.output_past = config.output_past
self.d_model_size = config.n_embd
self.num_layers = config.n_layer
@@ -289,6 +307,7 @@ class CTRLModel(CTRLPreTrainedModel):
position_ids=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
):
r"""
Return:
@@ -297,8 +316,7 @@ class CTRLModel(CTRLPreTrainedModel):
Sequence of hidden-states at the last layer of the model.
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
Can be used (see `past` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@@ -325,6 +343,17 @@ class CTRLModel(CTRLPreTrainedModel):
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
# If using past key value states, only the last tokens
# should be given as an input
if past is not None:
if input_ids is not None:
input_ids = input_ids[:, -1:]
if inputs_embeds is not None:
inputs_embeds = inputs_embeds[:, -1:]
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1:]
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
@@ -414,10 +443,15 @@ class CTRLModel(CTRLPreTrainedModel):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
outputs = h(
hidden_states, mask, layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask[i]
hidden_states,
mask,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask[i],
use_cache=use_cache,
)
hidden_states, present = outputs[:2]
if self.output_past:
if use_cache is True:
presents = presents + (present,)
if self.output_attentions:
@@ -429,7 +463,7 @@ class CTRLModel(CTRLPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if self.output_past:
if use_cache is True:
outputs = outputs + (presents,)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
@@ -462,7 +496,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
return {"input_ids": input_ids, "past": past}
return {"input_ids": input_ids, "past": past, "use_cache": kwargs["use_cache"]}
@add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
def forward(
@@ -475,6 +509,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=True,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@@ -492,8 +527,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
Can be used (see `past` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@@ -527,6 +561,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
)
hidden_states = transformer_outputs[0]