Improve BERT-like models performance with better self attention (#9124)
* Improve BERT-like models attention layers * Apply style * Put back error raising instead of assert * Update template * Fix copies * Apply raising valueerror in MPNet * Restore the copy check for the Intermediate layer in Longformer * Update longformer
This commit is contained in:
@@ -239,54 +239,58 @@ class TFMPNetSelfAttention(tf.keras.layers.Layer):
|
||||
|
||||
if config.hidden_size % config.num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
"The hidden size (%d) is not a multiple of the number of attention "
|
||||
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
||||
f"The hidden size ({config.hidden_size}) is not a multiple of the number "
|
||||
f"of attention heads ({config.num_attention_heads})"
|
||||
)
|
||||
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
assert config.hidden_size % config.num_attention_heads == 0
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
|
||||
self.q = tf.keras.layers.Dense(
|
||||
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="q"
|
||||
self.q = tf.keras.layers.experimental.EinsumDense(
|
||||
equation="abc,cde->abde",
|
||||
output_shape=(None, config.num_attention_heads, self.attention_head_size),
|
||||
bias_axes="de",
|
||||
kernel_initializer=get_initializer(initializer_range=config.initializer_range),
|
||||
name="q",
|
||||
)
|
||||
self.k = tf.keras.layers.Dense(
|
||||
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="k"
|
||||
self.k = tf.keras.layers.experimental.EinsumDense(
|
||||
equation="abc,cde->abde",
|
||||
output_shape=(None, config.num_attention_heads, self.attention_head_size),
|
||||
bias_axes="de",
|
||||
kernel_initializer=get_initializer(initializer_range=config.initializer_range),
|
||||
name="k",
|
||||
)
|
||||
self.v = tf.keras.layers.Dense(
|
||||
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="v"
|
||||
self.v = tf.keras.layers.experimental.EinsumDense(
|
||||
equation="abc,cde->abde",
|
||||
output_shape=(None, config.num_attention_heads, self.attention_head_size),
|
||||
bias_axes="de",
|
||||
kernel_initializer=get_initializer(initializer_range=config.initializer_range),
|
||||
name="v",
|
||||
)
|
||||
self.o = tf.keras.layers.Dense(
|
||||
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="o"
|
||||
self.o = tf.keras.layers.experimental.EinsumDense(
|
||||
equation="abcd,cde->abe",
|
||||
output_shape=(None, self.all_head_size),
|
||||
bias_axes="e",
|
||||
kernel_initializer=get_initializer(initializer_range=config.initializer_range),
|
||||
name="o",
|
||||
)
|
||||
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, x, batch_size):
|
||||
x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
||||
|
||||
return tf.transpose(x, perm=[0, 2, 1, 3])
|
||||
|
||||
def call(self, hidden_states, attention_mask, head_mask, output_attentions, position_bias=None, training=False):
|
||||
batch_size = shape_list(hidden_states)[0]
|
||||
|
||||
q = self.q(hidden_states)
|
||||
k = self.k(hidden_states)
|
||||
v = self.v(hidden_states)
|
||||
|
||||
q = self.transpose_for_scores(q, batch_size)
|
||||
k = self.transpose_for_scores(k, batch_size)
|
||||
v = self.transpose_for_scores(v, batch_size)
|
||||
|
||||
attention_scores = tf.matmul(q, k, transpose_b=True)
|
||||
dk = tf.cast(shape_list(k)[-1], attention_scores.dtype)
|
||||
attention_scores = attention_scores / tf.math.sqrt(dk)
|
||||
dk = tf.cast(x=self.attention_head_size, dtype=q.dtype)
|
||||
q = tf.multiply(x=q, y=tf.math.rsqrt(x=dk))
|
||||
attention_scores = tf.einsum("aecd,abcd->acbe", k, q)
|
||||
|
||||
# Apply relative position embedding (precomputed in MPNetEncoder) if provided.
|
||||
if position_bias is not None:
|
||||
attention_scores += position_bias
|
||||
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask is (precomputed for all layers in TFMPNetModel call() function)
|
||||
attention_scores = attention_scores + attention_mask
|
||||
|
||||
attention_probs = tf.nn.softmax(attention_scores, axis=-1)
|
||||
@@ -296,9 +300,7 @@ class TFMPNetSelfAttention(tf.keras.layers.Layer):
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
c = tf.matmul(attention_probs, v)
|
||||
c = tf.transpose(c, perm=[0, 2, 1, 3])
|
||||
c = tf.reshape(c, (batch_size, -1, self.all_head_size))
|
||||
c = tf.einsum("acbe,aecd->abcd", attention_probs, v)
|
||||
o = self.o(c)
|
||||
|
||||
outputs = (o, attention_probs) if output_attentions else (o,)
|
||||
@@ -330,18 +332,22 @@ class TFMPNetIntermediate(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.dense = tf.keras.layers.Dense(
|
||||
config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||
self.dense = tf.keras.layers.experimental.EinsumDense(
|
||||
equation="abc,cd->abd",
|
||||
output_shape=(None, config.intermediate_size),
|
||||
bias_axes="d",
|
||||
kernel_initializer=get_initializer(initializer_range=config.initializer_range),
|
||||
name="dense",
|
||||
)
|
||||
|
||||
if isinstance(config.hidden_act, str):
|
||||
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
|
||||
self.intermediate_act_fn = get_tf_activation(activation_string=config.hidden_act)
|
||||
else:
|
||||
self.intermediate_act_fn = config.hidden_act
|
||||
|
||||
def call(self, hidden_states):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||
hidden_states = self.dense(inputs=hidden_states)
|
||||
hidden_states = self.intermediate_act_fn(inputs=hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -351,16 +357,20 @@ class TFMPNetOutput(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.dense = tf.keras.layers.Dense(
|
||||
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||
self.dense = tf.keras.layers.experimental.EinsumDense(
|
||||
equation="abc,cd->abd",
|
||||
bias_axes="d",
|
||||
output_shape=(None, config.hidden_size),
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name="dense",
|
||||
)
|
||||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
|
||||
|
||||
def call(self, hidden_states, input_tensor, training=False):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
hidden_states = self.dense(inputs=hidden_states)
|
||||
hidden_states = self.dropout(inputs=hidden_states, training=training)
|
||||
hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
Reference in New Issue
Block a user