TF: XLA stable softmax (#16892)

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Joao Gante
2022-04-25 20:10:51 +01:00
committed by GitHub
parent 8246caf3eb
commit e03966e404
49 changed files with 205 additions and 137 deletions

View File

@@ -53,7 +53,7 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list
from ...tf_utils import shape_list, stable_softmax
from ...utils import logging
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
@@ -244,7 +244,7 @@ class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer)
attention_scores = tf.add(attention_scores, attention_mask)
# Normalize the attention scores to probabilities.
attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
attention_probs = stable_softmax(logits=attention_scores, axis=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
@@ -1665,8 +1665,8 @@ from ...modeling_tf_utils import (
TFWrappedEmbeddings,
keras_serializable,
unpack_inputs,
); from ...tf_utils import (shape_list,
)
from ...tf_utils import shape_list, stable_softmax
from ...utils import logging
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
@@ -1855,7 +1855,7 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they