[Config, Caching] Remove output_past everywhere and replace by use_cache argument (#3734)

* remove output_past from pt

* make style

* add optional input length for gpt2

* add use cache to prepare input

* save memory in gpt2

* correct gpt2 test inputs

* make past input optional for gpt2

* finish use_cache for all models

* make style

* delete modeling_gpt2 change in test file

* correct docstring

* correct is true statements for gpt2
This commit is contained in:
Patrick von Platen
2020-04-14 20:40:28 +02:00
committed by GitHub
parent 092cf881a5
commit 01c37dcdb5
15 changed files with 342 additions and 168 deletions

View File

@@ -94,7 +94,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
return tf.transpose(x, perm=[0, 2, 1, 3])
def call(self, inputs, training=False):
v, k, q, mask, layer_past, attention_mask, head_mask = inputs
v, k, q, mask, layer_past, attention_mask, head_mask, use_cache = inputs
batch_size = shape_list(q)[0]
q = self.Wq(q)
@@ -104,11 +104,25 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
q = self.split_into_heads(q, batch_size)
k = self.split_into_heads(k, batch_size)
v = self.split_into_heads(v, batch_size)
if layer_past is not None:
past_key, past_value = tf.unstack(layer_past, axis=0)
k = tf.concat((past_key, k), axis=-2)
v = tf.concat((past_value, v), axis=-2)
present = tf.stack((k, v), axis=0)
# to cope with keras serialization
# we need to cast `use_cache` to correct bool
# if it is a tensor
if tf.is_tensor(use_cache):
if hasattr(use_cache, "numpy"):
use_cache = bool(use_cache.numpy())
else:
use_cache = True
if use_cache is True:
present = tf.stack((k, v), axis=0)
else:
present = (None,)
output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
scaled_attention = tf.transpose(output[0], perm=[0, 2, 1, 3])
@@ -147,10 +161,10 @@ class TFEncoderLayer(tf.keras.layers.Layer):
self.dropout2 = tf.keras.layers.Dropout(rate)
def call(self, inputs, training=False):
x, mask, layer_past, attention_mask, head_mask = inputs
x, mask, layer_past, attention_mask, head_mask, use_cache = inputs
normed = self.layernorm1(x)
attn_outputs = self.multi_head_attention(
[normed, normed, normed, mask, layer_past, attention_mask, head_mask], training=training
[normed, normed, normed, mask, layer_past, attention_mask, head_mask, use_cache], training=training
)
attn_output = attn_outputs[0]
attn_output = self.dropout1(attn_output, training=training)
@@ -173,7 +187,6 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
super().__init__(**kwargs)
self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.output_past = config.output_past
self.d_model_size = config.n_embd
self.num_layers = config.n_layer
@@ -220,8 +233,10 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
position_ids=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
training=False,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
past = inputs[1] if len(inputs) > 1 else past
@@ -230,7 +245,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
position_ids = inputs[4] if len(inputs) > 4 else position_ids
head_mask = inputs[5] if len(inputs) > 5 else head_mask
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
assert len(inputs) <= 7, "Too many inputs."
use_cache = inputs[7] if len(inputs) > 7 else use_cache
assert len(inputs) <= 8, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
past = inputs.get("past", past)
@@ -239,10 +255,21 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
assert len(inputs) <= 7, "Too many inputs."
use_cache = inputs.get("use_cache", use_cache)
assert len(inputs) <= 8, "Too many inputs."
else:
input_ids = inputs
# If using past key value states, only the last tokens
# should be given as an input
if past is not None:
if input_ids is not None:
input_ids = input_ids[:, -1:]
if inputs_embeds is not None:
inputs_embeds = inputs_embeds[:, -1:]
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1:]
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
@@ -319,10 +346,10 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
for i, (h, layer_past) in enumerate(zip(self.h, past)):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = h([hidden_states, mask, layer_past, attention_mask, head_mask[i]], training=training)
outputs = h([hidden_states, mask, layer_past, attention_mask, head_mask[i], use_cache], training=training)
hidden_states, present = outputs[:2]
if self.output_past:
if use_cache is True:
presents = presents + (present,)
if self.output_attentions:
@@ -334,7 +361,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if self.output_past:
if use_cache is True:
outputs = outputs + (presents,)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
@@ -386,6 +413,7 @@ CTRL_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
If `past` is used, optionally only the last `input_ids` have to be input (see `past`).
Indices can be obtained using :class:`transformers.CTRLTokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and
@@ -394,8 +422,10 @@ CTRL_INPUTS_DOCSTRING = r"""
`What are input IDs? <../glossary.html#input-ids>`__
past (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `past` output below). Can be used to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
(see `past` output below). Can be used to speed up sequential decoding.
If `past` is used, the user can optionally input only the last `input_ids`
(those that don't have their past given to this model) of shape :obj:`(batch_size, 1)`
instead of all `input_ids` of shape :obj:`(batch_size, sequence_length)`.
attention_mask (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
@@ -406,6 +436,7 @@ CTRL_INPUTS_DOCSTRING = r"""
Segment token indices to indicate first and second portions of the inputs.
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
corresponds to a `sentence B` token.
If `past` is used, optionally only the last `token_type_ids` have to be input (see `past`).
`What are token type IDs? <../glossary.html#token-type-ids>`_
position_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@@ -421,6 +452,10 @@ CTRL_INPUTS_DOCSTRING = r"""
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
If `past` is used, optionally only the last `inputs_embeds` have to be input (see `past`).
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
If `use_cache` is True, `past` key value states are returned and
can be used to speed up decoding (see `past`).
training (:obj:`boolean`, `optional`, defaults to :obj:`False`):
Whether to activate dropout modules (if set to :obj:`True`) during training or to de-activate them
(if set to :obj:`False`) for evaluation.
@@ -514,7 +549,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel):
if past:
inputs = tf.expand_dims(inputs[:, -1], -1)
return {"inputs": inputs, "past": past}
return {"inputs": inputs, "past": past, "use_cache": kwargs["use_cache"]}
@add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
def call(self, inputs, **kwargs):