Improve BERT-like models performance with better self attention (#9124)

* Improve BERT-like models attention layers

* Apply style

* Put back error raising instead of assert

* Update template

* Fix copies

* Apply raising valueerror in MPNet

* Restore the copy check for the Intermediate layer in Longformer

* Update longformer
This commit is contained in:
Julien Plu
2020-12-21 13:10:15 +01:00
committed by GitHub
parent 6b034309ca
commit 5a8a4eb187
8 changed files with 348 additions and 271 deletions

View File

@@ -239,54 +239,58 @@ class TFMPNetSelfAttention(tf.keras.layers.Layer):
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
f"The hidden size ({config.hidden_size}) is not a multiple of the number "
f"of attention heads ({config.num_attention_heads})"
)
self.num_attention_heads = config.num_attention_heads
assert config.hidden_size % config.num_attention_heads == 0
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.q = tf.keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="q"
self.q = tf.keras.layers.experimental.EinsumDense(
equation="abc,cde->abde",
output_shape=(None, config.num_attention_heads, self.attention_head_size),
bias_axes="de",
kernel_initializer=get_initializer(initializer_range=config.initializer_range),
name="q",
)
self.k = tf.keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="k"
self.k = tf.keras.layers.experimental.EinsumDense(
equation="abc,cde->abde",
output_shape=(None, config.num_attention_heads, self.attention_head_size),
bias_axes="de",
kernel_initializer=get_initializer(initializer_range=config.initializer_range),
name="k",
)
self.v = tf.keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="v"
self.v = tf.keras.layers.experimental.EinsumDense(
equation="abc,cde->abde",
output_shape=(None, config.num_attention_heads, self.attention_head_size),
bias_axes="de",
kernel_initializer=get_initializer(initializer_range=config.initializer_range),
name="v",
)
self.o = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="o"
self.o = tf.keras.layers.experimental.EinsumDense(
equation="abcd,cde->abe",
output_shape=(None, self.all_head_size),
bias_axes="e",
kernel_initializer=get_initializer(initializer_range=config.initializer_range),
name="o",
)
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x, batch_size):
x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
return tf.transpose(x, perm=[0, 2, 1, 3])
def call(self, hidden_states, attention_mask, head_mask, output_attentions, position_bias=None, training=False):
batch_size = shape_list(hidden_states)[0]
q = self.q(hidden_states)
k = self.k(hidden_states)
v = self.v(hidden_states)
q = self.transpose_for_scores(q, batch_size)
k = self.transpose_for_scores(k, batch_size)
v = self.transpose_for_scores(v, batch_size)
attention_scores = tf.matmul(q, k, transpose_b=True)
dk = tf.cast(shape_list(k)[-1], attention_scores.dtype)
attention_scores = attention_scores / tf.math.sqrt(dk)
dk = tf.cast(x=self.attention_head_size, dtype=q.dtype)
q = tf.multiply(x=q, y=tf.math.rsqrt(x=dk))
attention_scores = tf.einsum("aecd,abcd->acbe", k, q)
# Apply relative position embedding (precomputed in MPNetEncoder) if provided.
if position_bias is not None:
attention_scores += position_bias
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in TFMPNetModel call() function)
attention_scores = attention_scores + attention_mask
attention_probs = tf.nn.softmax(attention_scores, axis=-1)
@@ -296,9 +300,7 @@ class TFMPNetSelfAttention(tf.keras.layers.Layer):
if head_mask is not None:
attention_probs = attention_probs * head_mask
c = tf.matmul(attention_probs, v)
c = tf.transpose(c, perm=[0, 2, 1, 3])
c = tf.reshape(c, (batch_size, -1, self.all_head_size))
c = tf.einsum("acbe,aecd->abcd", attention_probs, v)
o = self.o(c)
outputs = (o, attention_probs) if output_attentions else (o,)
@@ -330,18 +332,22 @@ class TFMPNetIntermediate(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
self.dense = tf.keras.layers.experimental.EinsumDense(
equation="abc,cd->abd",
output_shape=(None, config.intermediate_size),
bias_axes="d",
kernel_initializer=get_initializer(initializer_range=config.initializer_range),
name="dense",
)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
self.intermediate_act_fn = get_tf_activation(activation_string=config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
def call(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.intermediate_act_fn(inputs=hidden_states)
return hidden_states
@@ -351,16 +357,20 @@ class TFMPNetOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
self.dense = tf.keras.layers.experimental.EinsumDense(
equation="abc,cd->abd",
bias_axes="d",
output_shape=(None, config.hidden_size),
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def call(self, hidden_states, input_tensor, training=False):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.dropout(inputs=hidden_states, training=training)
hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)
return hidden_states