[All models] Extend config.output_attentions with output_attentions function arguments (#4538)
* DOC: Replace instances of ``config.output_attentions`` with function argument ``output_attentions``
* DOC: Apply Black Formatting
* Fix errors where output_attentions was undefined
* Remove output_attentions in classes per review
* Fix regressions on tests having `output_attention`
* Fix further regressions in tests relating to `output_attentions`
  Ensure proper propagation of `output_attentions` as a function parameter to all model subclasses
* Fix more regressions in `test_output_attentions`
* Fix issues with BertEncoder
* Rename related variables to `output_attentions`
* fix pytorch tests
* fix bert and gpt2 tf
* Fix most TF tests for `test_output_attentions`
* Fix linter errors and more TF tests
* fix conflicts
* DOC: Apply Black Formatting
* Fix errors where output_attentions was undefined
* Remove output_attentions in classes per review
* Fix regressions on tests having `output_attention`
* fix conflicts
* fix conflicts
* fix conflicts
* fix conflicts
* fix pytorch tests
* fix conflicts
* fix conflicts
* Fix linter errors and more TF tests
* fix tf tests
* make style
* fix isort
* improve output_attentions
* improve tensorflow

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
committed by
GitHub
parent
f90bc44d9a
commit
6e603cb789
@@ -23,7 +23,13 @@ import tensorflow as tf
|
||||
|
||||
from .configuration_ctrl import CTRLConfig
|
||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, keras_serializable, shape_list
|
||||
from .modeling_tf_utils import (
|
||||
TFPreTrainedModel,
|
||||
TFSharedEmbeddings,
|
||||
cast_bool_to_primitive,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
)
|
||||
from .tokenization_utils import BatchEncoding
|
||||
|
||||
|
||||
@@ -78,9 +84,8 @@ def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=N
|
||||
|
||||
|
||||
class TFMultiHeadAttention(tf.keras.layers.Layer):
|
||||
def __init__(self, d_model_size, num_heads, output_attentions=False, **kwargs):
|
||||
def __init__(self, d_model_size, num_heads, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.output_attentions = output_attentions
|
||||
self.num_heads = num_heads
|
||||
self.d_model_size = d_model_size
|
||||
|
||||
@@ -97,7 +102,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
|
||||
return tf.transpose(x, perm=[0, 2, 1, 3])
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
v, k, q, mask, layer_past, attention_mask, head_mask, use_cache = inputs
|
||||
v, k, q, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions = inputs
|
||||
batch_size = shape_list(q)[0]
|
||||
|
||||
q = self.Wq(q)
|
||||
@@ -114,13 +119,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
|
||||
v = tf.concat((past_value, v), axis=-2)
|
||||
|
||||
# to cope with keras serialization
|
||||
# we need to cast `use_cache` to correct bool
|
||||
# if it is a tensor
|
||||
if tf.is_tensor(use_cache):
|
||||
if hasattr(use_cache, "numpy"):
|
||||
use_cache = bool(use_cache.numpy())
|
||||
else:
|
||||
use_cache = True
|
||||
use_cache = cast_bool_to_primitive(use_cache, True)
|
||||
|
||||
if use_cache is True:
|
||||
present = tf.stack((k, v), axis=0)
|
||||
@@ -134,7 +133,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
|
||||
output = self.dense(original_size_attention)
|
||||
|
||||
outputs = (output, present)
|
||||
if self.output_attentions:
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
outputs = outputs + (attn,)
|
||||
return outputs
|
||||
|
||||
@@ -147,14 +146,10 @@ def point_wise_feed_forward_network(d_model_size, dff, name=""):
|
||||
|
||||
|
||||
class TFEncoderLayer(tf.keras.layers.Layer):
|
||||
def __init__(
|
||||
self, d_model_size, num_heads, dff, rate=0.1, layer_norm_epsilon=1e-6, output_attentions=False, **kwargs
|
||||
):
|
||||
def __init__(self, d_model_size, num_heads, dff, rate=0.1, layer_norm_epsilon=1e-6, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.multi_head_attention = TFMultiHeadAttention(
|
||||
d_model_size, num_heads, output_attentions, name="multi_head_attention"
|
||||
)
|
||||
self.multi_head_attention = TFMultiHeadAttention(d_model_size, num_heads, name="multi_head_attention")
|
||||
self.ffn = point_wise_feed_forward_network(d_model_size, dff, name="ffn")
|
||||
|
||||
self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layernorm1")
|
||||
@@ -164,10 +159,11 @@ class TFEncoderLayer(tf.keras.layers.Layer):
|
||||
self.dropout2 = tf.keras.layers.Dropout(rate)
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
x, mask, layer_past, attention_mask, head_mask, use_cache = inputs
|
||||
x, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions = inputs
|
||||
normed = self.layernorm1(x)
|
||||
attn_outputs = self.multi_head_attention(
|
||||
[normed, normed, normed, mask, layer_past, attention_mask, head_mask, use_cache], training=training
|
||||
[normed, normed, normed, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions],
|
||||
training=training,
|
||||
)
|
||||
attn_output = attn_outputs[0]
|
||||
attn_output = self.dropout1(attn_output, training=training)
|
||||
@@ -208,7 +204,6 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
config.dff,
|
||||
config.resid_pdrop,
|
||||
config.layer_norm_epsilon,
|
||||
config.output_attentions,
|
||||
name="h_._{}".format(i),
|
||||
)
|
||||
for i in range(config.n_layer)
|
||||
@@ -237,6 +232,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=True,
|
||||
output_attentions=None,
|
||||
training=False,
|
||||
):
|
||||
|
||||
@@ -249,7 +245,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
head_mask = inputs[5] if len(inputs) > 5 else head_mask
|
||||
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
|
||||
use_cache = inputs[7] if len(inputs) > 7 else use_cache
|
||||
assert len(inputs) <= 8, "Too many inputs."
|
||||
output_attentions = inputs[8] if len(inputs) > 8 else output_attentions
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
past = inputs.get("past", past)
|
||||
@@ -259,10 +256,13 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
use_cache = inputs.get("use_cache", use_cache)
|
||||
assert len(inputs) <= 8, "Too many inputs."
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
|
||||
|
||||
# If using past key value states, only the last tokens
|
||||
# should be given as an input
|
||||
if past is not None:
|
||||
@@ -349,13 +349,16 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
for i, (h, layer_past) in enumerate(zip(self.h, past)):
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
|
||||
outputs = h([hidden_states, mask, layer_past, attention_mask, head_mask[i], use_cache], training=training)
|
||||
outputs = h(
|
||||
[hidden_states, mask, layer_past, attention_mask, head_mask[i], use_cache, output_attentions],
|
||||
training=training,
|
||||
)
|
||||
hidden_states, present = outputs[:2]
|
||||
|
||||
if use_cache is True:
|
||||
presents = presents + (present,)
|
||||
|
||||
if self.output_attentions:
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
all_attentions.append(outputs[2])
|
||||
|
||||
hidden_states = self.layernorm(hidden_states)
|
||||
@@ -368,7 +371,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
outputs = outputs + (presents,)
|
||||
if self.output_hidden_states:
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
if self.output_attentions:
|
||||
if cast_bool_to_primitive(output_attentions) is True:
|
||||
# let the number of heads free (-1) so we can extract attention even after head pruning
|
||||
attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
|
||||
all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
|
||||
@@ -489,7 +492,7 @@ class TFCTRLModel(TFCTRLPreTrainedModel):
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
|
||||
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
|
||||
Tuple of :obj:`tf.Tensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
@@ -569,7 +572,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel):
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
|
||||
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
|
||||
Tuple of :obj:`tf.Tensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user