[All models] Extend config.output_attentions with output_attentions function arguments (#4538)

* DOC: Replace instances of ``config.output_attentions`` with function argument ``output_attentions``

* DOC: Apply Black Formatting

* Fix errors where output_attentions was undefined

* Remove output_attentions in classes per review

* Fix regressions on tests having `output_attentions`

* Fix further regressions in tests relating to `output_attentions`

Ensure proper propagation of `output_attentions` as a function parameter
to all model subclasses

* Fix more regressions in `test_output_attentions`

* Fix issues with BertEncoder

* Rename related variables to `output_attentions`

* fix pytorch tests

* fix bert and gpt2 tf

* Fix most TF tests for `test_output_attentions`

* Fix linter errors and more TF tests

* fix conflicts

* DOC: Apply Black Formatting

* Fix errors where output_attentions was undefined

* Remove output_attentions in classes per review

* Fix regressions on tests having `output_attentions`

* fix conflicts

* fix conflicts

* fix conflicts

* fix conflicts

* fix pytorch tests

* fix conflicts

* fix conflicts

* Fix linter errors and more TF tests

* fix tf tests

* make style

* fix isort

* improve output_attentions

* improve tensorflow

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Bharat Raghunathan
2020-06-10 03:09:06 +05:30
committed by GitHub
parent f90bc44d9a
commit 6e603cb789
38 changed files with 1108 additions and 549 deletions

View File

@@ -23,7 +23,13 @@ import tensorflow as tf
from .configuration_ctrl import CTRLConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, keras_serializable, shape_list
from .modeling_tf_utils import (
TFPreTrainedModel,
TFSharedEmbeddings,
cast_bool_to_primitive,
keras_serializable,
shape_list,
)
from .tokenization_utils import BatchEncoding
@@ -78,9 +84,8 @@ def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=N
class TFMultiHeadAttention(tf.keras.layers.Layer):
def __init__(self, d_model_size, num_heads, output_attentions=False, **kwargs):
def __init__(self, d_model_size, num_heads, **kwargs):
super().__init__(**kwargs)
self.output_attentions = output_attentions
self.num_heads = num_heads
self.d_model_size = d_model_size
@@ -97,7 +102,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
return tf.transpose(x, perm=[0, 2, 1, 3])
def call(self, inputs, training=False):
v, k, q, mask, layer_past, attention_mask, head_mask, use_cache = inputs
v, k, q, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions = inputs
batch_size = shape_list(q)[0]
q = self.Wq(q)
@@ -114,13 +119,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
v = tf.concat((past_value, v), axis=-2)
# to cope with keras serialization
# we need to cast `use_cache` to correct bool
# if it is a tensor
if tf.is_tensor(use_cache):
if hasattr(use_cache, "numpy"):
use_cache = bool(use_cache.numpy())
else:
use_cache = True
use_cache = cast_bool_to_primitive(use_cache, True)
if use_cache is True:
present = tf.stack((k, v), axis=0)
@@ -134,7 +133,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
output = self.dense(original_size_attention)
outputs = (output, present)
if self.output_attentions:
if cast_bool_to_primitive(output_attentions) is True:
outputs = outputs + (attn,)
return outputs
@@ -147,14 +146,10 @@ def point_wise_feed_forward_network(d_model_size, dff, name=""):
class TFEncoderLayer(tf.keras.layers.Layer):
def __init__(
self, d_model_size, num_heads, dff, rate=0.1, layer_norm_epsilon=1e-6, output_attentions=False, **kwargs
):
def __init__(self, d_model_size, num_heads, dff, rate=0.1, layer_norm_epsilon=1e-6, **kwargs):
super().__init__(**kwargs)
self.multi_head_attention = TFMultiHeadAttention(
d_model_size, num_heads, output_attentions, name="multi_head_attention"
)
self.multi_head_attention = TFMultiHeadAttention(d_model_size, num_heads, name="multi_head_attention")
self.ffn = point_wise_feed_forward_network(d_model_size, dff, name="ffn")
self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layernorm1")
@@ -164,10 +159,11 @@ class TFEncoderLayer(tf.keras.layers.Layer):
self.dropout2 = tf.keras.layers.Dropout(rate)
def call(self, inputs, training=False):
x, mask, layer_past, attention_mask, head_mask, use_cache = inputs
x, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions = inputs
normed = self.layernorm1(x)
attn_outputs = self.multi_head_attention(
[normed, normed, normed, mask, layer_past, attention_mask, head_mask, use_cache], training=training
[normed, normed, normed, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions],
training=training,
)
attn_output = attn_outputs[0]
attn_output = self.dropout1(attn_output, training=training)
@@ -208,7 +204,6 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
config.dff,
config.resid_pdrop,
config.layer_norm_epsilon,
config.output_attentions,
name="h_._{}".format(i),
)
for i in range(config.n_layer)
@@ -237,6 +232,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
head_mask=None,
inputs_embeds=None,
use_cache=True,
output_attentions=None,
training=False,
):
@@ -249,7 +245,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
head_mask = inputs[5] if len(inputs) > 5 else head_mask
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
use_cache = inputs[7] if len(inputs) > 7 else use_cache
assert len(inputs) <= 8, "Too many inputs."
output_attentions = inputs[8] if len(inputs) > 8 else output_attentions
assert len(inputs) <= 9, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
past = inputs.get("past", past)
@@ -259,10 +256,13 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
use_cache = inputs.get("use_cache", use_cache)
assert len(inputs) <= 8, "Too many inputs."
output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 9, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
# If using past key value states, only the last tokens
# should be given as an input
if past is not None:
@@ -349,13 +349,16 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
for i, (h, layer_past) in enumerate(zip(self.h, past)):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = h([hidden_states, mask, layer_past, attention_mask, head_mask[i], use_cache], training=training)
outputs = h(
[hidden_states, mask, layer_past, attention_mask, head_mask[i], use_cache, output_attentions],
training=training,
)
hidden_states, present = outputs[:2]
if use_cache is True:
presents = presents + (present,)
if self.output_attentions:
if cast_bool_to_primitive(output_attentions) is True:
all_attentions.append(outputs[2])
hidden_states = self.layernorm(hidden_states)
@@ -368,7 +371,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
outputs = outputs + (presents,)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
if cast_bool_to_primitive(output_attentions) is True:
# let the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
@@ -489,7 +492,7 @@ class TFCTRLModel(TFCTRLPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
@@ -569,7 +572,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.