TF: XLA stable softmax (#16892)

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Joao Gante
2022-04-25 20:10:51 +01:00
committed by GitHub
parent 8246caf3eb
commit e03966e404
49 changed files with 205 additions and 137 deletions

View File

@@ -84,6 +84,7 @@ if is_tf_available():
TFSampleEncoderDecoderOutput,
)
from transformers.modeling_tf_utils import unpack_inputs
from transformers.tf_utils import stable_softmax
if _tf_gpu_memory_limit is not None:
gpus = tf.config.list_physical_devices("GPU")
@@ -1709,6 +1710,41 @@ class UtilsFunctionsTest(unittest.TestCase):
self.assertFalse(output[3])
self.assertFalse(output[4])
# Tests whether the stable softmax is stable on CPU, with and without XLA
def test_xla_stable_softmax(self):
large_penalty = -1e9
n_tokens = 10
batch_size = 8
def masked_softmax(x, boolean_mask):
numerical_mask = (1.0 - tf.cast(boolean_mask, dtype=tf.float32)) * large_penalty
masked_x = x + numerical_mask
return stable_softmax(masked_x)
xla_masked_softmax = tf.function(masked_softmax, jit_compile=True)
xla_stable_softmax = tf.function(stable_softmax, jit_compile=True)
x = tf.random.normal((batch_size, n_tokens))
# Same outcome regardless of the boolean mask here
masked_tokens = random.randint(0, n_tokens)
boolean_mask = tf.convert_to_tensor([[1] * (n_tokens - masked_tokens) + [0] * masked_tokens], dtype=tf.int32)
# We can randomly mask a random numerical input OUTSIDE XLA
numerical_mask = (1.0 - tf.cast(boolean_mask, dtype=tf.float32)) * large_penalty
masked_x = x + numerical_mask
xla_out = xla_stable_softmax(masked_x)
out = stable_softmax(masked_x)
assert tf.experimental.numpy.allclose(xla_out, out)
# The stable softmax has the same output as the original softmax
unstable_out = tf.nn.softmax(masked_x)
assert tf.experimental.numpy.allclose(unstable_out, out)
# We can randomly mask a random numerical input INSIDE XLA
xla_out = xla_masked_softmax(x, boolean_mask)
out = masked_softmax(x, boolean_mask)
assert tf.experimental.numpy.allclose(xla_out, out)
@require_tf
@is_staging_test