[Config, Caching] Remove output_past everywhere and replace by use_cache argument (#3734)
* remove output_past from pt * make style * add optional input length for gpt2 * add use cache to prepare input * save memory in gpt2 * correct gpt2 test inputs * make past input optional for gpt2 * finish use_cache for all models * make style * delete modeling_gpt2 change in test file * correct docstring * correct is true statements for gpt2
This commit is contained in:
committed by
GitHub
parent
092cf881a5
commit
01c37dcdb5
@@ -94,7 +94,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
|
||||
return tf.transpose(x, perm=[0, 2, 1, 3])
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
v, k, q, mask, layer_past, attention_mask, head_mask = inputs
|
||||
v, k, q, mask, layer_past, attention_mask, head_mask, use_cache = inputs
|
||||
batch_size = shape_list(q)[0]
|
||||
|
||||
q = self.Wq(q)
|
||||
@@ -104,11 +104,25 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
|
||||
q = self.split_into_heads(q, batch_size)
|
||||
k = self.split_into_heads(k, batch_size)
|
||||
v = self.split_into_heads(v, batch_size)
|
||||
|
||||
if layer_past is not None:
|
||||
past_key, past_value = tf.unstack(layer_past, axis=0)
|
||||
k = tf.concat((past_key, k), axis=-2)
|
||||
v = tf.concat((past_value, v), axis=-2)
|
||||
present = tf.stack((k, v), axis=0)
|
||||
|
||||
# to cope with keras serialization
|
||||
# we need to cast `use_cache` to correct bool
|
||||
# if it is a tensor
|
||||
if tf.is_tensor(use_cache):
|
||||
if hasattr(use_cache, "numpy"):
|
||||
use_cache = bool(use_cache.numpy())
|
||||
else:
|
||||
use_cache = True
|
||||
|
||||
if use_cache is True:
|
||||
present = tf.stack((k, v), axis=0)
|
||||
else:
|
||||
present = (None,)
|
||||
|
||||
output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
|
||||
scaled_attention = tf.transpose(output[0], perm=[0, 2, 1, 3])
|
||||
@@ -147,10 +161,10 @@ class TFEncoderLayer(tf.keras.layers.Layer):
|
||||
self.dropout2 = tf.keras.layers.Dropout(rate)
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
x, mask, layer_past, attention_mask, head_mask = inputs
|
||||
x, mask, layer_past, attention_mask, head_mask, use_cache = inputs
|
||||
normed = self.layernorm1(x)
|
||||
attn_outputs = self.multi_head_attention(
|
||||
[normed, normed, normed, mask, layer_past, attention_mask, head_mask], training=training
|
||||
[normed, normed, normed, mask, layer_past, attention_mask, head_mask, use_cache], training=training
|
||||
)
|
||||
attn_output = attn_outputs[0]
|
||||
attn_output = self.dropout1(attn_output, training=training)
|
||||
@@ -173,7 +187,6 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
super().__init__(**kwargs)
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.output_attentions = config.output_attentions
|
||||
self.output_past = config.output_past
|
||||
|
||||
self.d_model_size = config.n_embd
|
||||
self.num_layers = config.n_layer
|
||||
@@ -220,8 +233,10 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=True,
|
||||
training=False,
|
||||
):
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
past = inputs[1] if len(inputs) > 1 else past
|
||||
@@ -230,7 +245,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
position_ids = inputs[4] if len(inputs) > 4 else position_ids
|
||||
head_mask = inputs[5] if len(inputs) > 5 else head_mask
|
||||
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
use_cache = inputs[7] if len(inputs) > 7 else use_cache
|
||||
assert len(inputs) <= 8, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
past = inputs.get("past", past)
|
||||
@@ -239,10 +255,21 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
position_ids = inputs.get("position_ids", position_ids)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
use_cache = inputs.get("use_cache", use_cache)
|
||||
assert len(inputs) <= 8, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
# If using past key value states, only the last tokens
|
||||
# should be given as an input
|
||||
if past is not None:
|
||||
if input_ids is not None:
|
||||
input_ids = input_ids[:, -1:]
|
||||
if inputs_embeds is not None:
|
||||
inputs_embeds = inputs_embeds[:, -1:]
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids[:, -1:]
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
@@ -319,10 +346,10 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
for i, (h, layer_past) in enumerate(zip(self.h, past)):
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
|
||||
outputs = h([hidden_states, mask, layer_past, attention_mask, head_mask[i]], training=training)
|
||||
outputs = h([hidden_states, mask, layer_past, attention_mask, head_mask[i], use_cache], training=training)
|
||||
hidden_states, present = outputs[:2]
|
||||
|
||||
if self.output_past:
|
||||
if use_cache is True:
|
||||
presents = presents + (present,)
|
||||
|
||||
if self.output_attentions:
|
||||
@@ -334,7 +361,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if self.output_past:
|
||||
if use_cache is True:
|
||||
outputs = outputs + (presents,)
|
||||
if self.output_hidden_states:
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
@@ -386,6 +413,7 @@ CTRL_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
If `past` is used, optionally only the last `input_ids` have to be input (see `past`).
|
||||
|
||||
Indices can be obtained using :class:`transformers.CTRLTokenizer`.
|
||||
See :func:`transformers.PreTrainedTokenizer.encode` and
|
||||
@@ -394,8 +422,10 @@ CTRL_INPUTS_DOCSTRING = r"""
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
past (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers`):
|
||||
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
(see `past` output below). Can be used to speed up sequential decoding. The token ids which have their past given to this model
|
||||
should not be passed as input ids as they have already been computed.
|
||||
(see `past` output below). Can be used to speed up sequential decoding.
|
||||
If `past` is used, the user can optionally input only the last `input_ids`
|
||||
(those that don't have their past given to this model) of shape :obj:`(batch_size, 1)`
|
||||
instead of all `input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||
attention_mask (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
@@ -406,6 +436,7 @@ CTRL_INPUTS_DOCSTRING = r"""
|
||||
Segment token indices to indicate first and second portions of the inputs.
|
||||
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
|
||||
corresponds to a `sentence B` token
|
||||
If `past` is used, optionally only the last `token_type_ids` have to be input (see `past`).
|
||||
|
||||
`What are token type IDs? <../glossary.html#token-type-ids>`_
|
||||
position_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
@@ -421,6 +452,10 @@ CTRL_INPUTS_DOCSTRING = r"""
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
If `past` is used, optionally only the last `input_embeds` have to be input (see `past`).
|
||||
use_cache (:obj:`bool`):
|
||||
If `use_cache` is True, `past` key value states are returned and
|
||||
can be used to speed up decoding (see `past`). Defaults to `True`.
|
||||
training (:obj:`boolean`, `optional`, defaults to :obj:`False`):
|
||||
Whether to activate dropout modules (if set to :obj:`True`) during training or to de-activate them
|
||||
(if set to :obj:`False`) for evaluation.
|
||||
@@ -514,7 +549,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel):
|
||||
if past:
|
||||
inputs = tf.expand_dims(inputs[:, -1], -1)
|
||||
|
||||
return {"inputs": inputs, "past": past}
|
||||
return {"inputs": inputs, "past": past, "use_cache": kwargs["use_cache"]}
|
||||
|
||||
@add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
|
||||
def call(self, inputs, **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user