TF: XLA stable softmax (#16892)
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -53,7 +53,7 @@ from ...modeling_tf_utils import (
|
||||
keras_serializable,
|
||||
unpack_inputs,
|
||||
)
|
||||
from ...tf_utils import shape_list
|
||||
from ...tf_utils import shape_list, stable_softmax
|
||||
from ...utils import logging
|
||||
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
|
||||
|
||||
@@ -244,7 +244,7 @@ class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer)
|
||||
attention_scores = tf.add(attention_scores, attention_mask)
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
|
||||
attention_probs = stable_softmax(logits=attention_scores, axis=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
@@ -1665,8 +1665,8 @@ from ...modeling_tf_utils import (
|
||||
TFWrappedEmbeddings,
|
||||
keras_serializable,
|
||||
unpack_inputs,
|
||||
); from ...tf_utils import (shape_list,
|
||||
)
|
||||
from ...tf_utils import shape_list, stable_softmax
|
||||
from ...utils import logging
|
||||
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
|
||||
|
||||
@@ -1855,7 +1855,7 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
|
||||
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
||||
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
|
||||
|
||||
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
|
||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
|
||||
Reference in New Issue
Block a user