TF: XLA stable softmax (#16892)

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Joao Gante
2022-04-25 20:10:51 +01:00
committed by GitHub
parent 8246caf3eb
commit e03966e404
49 changed files with 205 additions and 137 deletions

View File

@@ -16,7 +16,7 @@
import unittest
from transformers import GPT2Config, is_tf_available
from transformers.testing_utils import get_gpu_count, require_tf, slow
from transformers.testing_utils import require_tf, slow
from ..test_configuration_common import ConfigTester
from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
@@ -536,8 +536,6 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
self.assertListEqual(output_strings, expected_output_string)
@slow
@unittest.skipIf(not get_gpu_count(), "XLA not reliable on CPU")
# TODO: remove the skip when the XLA CPU softmax issue gets sorted
def test_lm_generate_gpt2_greedy_xla(self):
# TODO (Joao): convert this to an example with a batch size>1 with different input lengths that works (and fix
# the underlying problem)
@@ -563,30 +561,33 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
self.assertListEqual(output_strings, expected_output_strings)
@slow
@unittest.skipIf(not get_gpu_count(), "XLA not reliable on CPU")
# TODO: remove the skip when the XLA CPU softmax issue gets sorted
def test_lm_generate_gpt2_sample_xla(self):
# NOTE: due to the small numerical differences that are natural when we compile to XLA, sampling the same
# output out of the same seed is far from guaranteed. We can, however, confirm that the results are sensible
# and that we can seed both versions.
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
# forces the generation to happen on CPU, to avoid GPU-related quirks
with tf.device(":/CPU:0"):
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
sentence = ["The dog"]
expected_output_string = [
"The dog must be well educated to do anything. If anything, this must be her best friend"
]
expected_output_string_xla = ["The dog has been named in connection with the murder of a 20-year-old man in!"]
input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
output_ids = model.generate(input_ids, do_sample=True, seed=[7, 0])
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertListEqual(output_strings, expected_output_string)
sentence = ["The dog"]
expected_output_string = [
"The dog owner asked why did our vet decide there needed to be extra ventilation inside because most puppies"
]
expected_output_string_xla = [
"The dog has been named in connection with the murder of a 20-year-old man in!"
]
input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
xla_generate = tf.function(model.generate, jit_compile=True)
output_ids = xla_generate(input_ids, do_sample=True, seed=[7, 0])
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertListEqual(output_strings, expected_output_string_xla)
output_ids = model.generate(input_ids, do_sample=True, seed=[7, 0])
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertListEqual(output_strings, expected_output_string)
xla_generate = tf.function(model.generate, jit_compile=True)
output_ids = xla_generate(input_ids, do_sample=True, seed=[7, 0])
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertListEqual(output_strings, expected_output_string_xla)

View File

@@ -16,7 +16,7 @@
import unittest
from transformers import T5Config, is_tf_available
from transformers.testing_utils import get_gpu_count, require_sentencepiece, require_tf, require_tokenizers, slow
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
from transformers.utils import cached_property
from ..test_configuration_common import ConfigTester
@@ -481,8 +481,6 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
@require_tokenizers
class TFT5GenerationIntegrationTests(unittest.TestCase):
@slow
@unittest.skipIf(not get_gpu_count(), "XLA not reliable on CPU")
# TODO: remove the skip when the XLA CPU softmax issue gets sorted
def test_greedy_xla_generate_simple(self):
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")
@@ -534,30 +532,31 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
self.assertListEqual(expected_output_string, output_strings)
@slow
@unittest.skipIf(not get_gpu_count(), "XLA not reliable on CPU")
# TODO: remove the skip when the XLA CPU softmax issue gets sorted
def test_sample_xla_generate_simple(self):
# NOTE: due to the small numerical differences that are natural when we compile to XLA, sampling the same
# output out of the same seed is far from guaranteed (unlike this example). We can, however, confirm that the
# results are sensible and that we can seed both versions.
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")
# output out of the same seed is far from guaranteed. We can, however, confirm that the results are sensible
# and that we can seed both versions.
sentence = "Translate English to German: I have two bananas"
input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
expected_output_string = ["Ich habe 2 Bananen"]
expected_output_string_xla = ["Ich habe 2 Bananen"]
# forces the generation to happen on CPU, to avoid GPU-related quirks
with tf.device(":/CPU:0"):
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")
# seed set -> deterministic sampling sequence -> deterministic generation
output_ids = model.generate(input_ids, do_sample=True, seed=[42, 0])
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertListEqual(expected_output_string, output_strings)
sentence = "Translate English to German: I have two bananas"
input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
expected_output_string = ["Ich habe zwei Bananen"]
expected_output_string_xla = ["Ich habe 2 Bananen"]
xla_generate = tf.function(model.generate, jit_compile=True)
# seed set -> deterministic sampling sequence -> deterministic generation
output_ids_xla = xla_generate(input_ids, do_sample=True, seed=[42, 0])
output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True)
self.assertListEqual(expected_output_string_xla, output_strings_xla)
# seed set -> deterministic sampling sequence -> deterministic generation
output_ids = model.generate(input_ids, do_sample=True, seed=[42, 0])
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertListEqual(expected_output_string, output_strings)
xla_generate = tf.function(model.generate, jit_compile=True)
# seed set -> deterministic sampling sequence -> deterministic generation
output_ids_xla = xla_generate(input_ids, do_sample=True, seed=[42, 0])
output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True)
self.assertListEqual(expected_output_string_xla, output_strings_xla)
@slow
def test_sample_generate(self):

View File

@@ -84,6 +84,7 @@ if is_tf_available():
TFSampleEncoderDecoderOutput,
)
from transformers.modeling_tf_utils import unpack_inputs
from transformers.tf_utils import stable_softmax
if _tf_gpu_memory_limit is not None:
gpus = tf.config.list_physical_devices("GPU")
@@ -1709,6 +1710,41 @@ class UtilsFunctionsTest(unittest.TestCase):
self.assertFalse(output[3])
self.assertFalse(output[4])
# Tests whether the stable softmax is stable on CPU, with and without XLA
def test_xla_stable_softmax(self):
    """Compare `stable_softmax` outputs with and without XLA compilation.

    A random `(batch_size, n_tokens)` input is combined with an additive mask
    that puts `large_penalty` on masked positions. Three checks are made:

    1. XLA-compiled vs. eager `stable_softmax`, mask applied outside the compiled function.
    2. Eager `stable_softmax` vs. `tf.nn.softmax` on the same masked input.
    3. XLA-compiled vs. eager `masked_softmax`, mask applied inside the compiled function.
    """
    large_penalty = -1e9
    n_tokens = 10
    batch_size = 8

    def masked_softmax(x, boolean_mask):
        # Mask entries equal to 0 contribute `large_penalty`; entries equal to 1 contribute 0.
        numerical_mask = (1.0 - tf.cast(boolean_mask, dtype=tf.float32)) * large_penalty
        masked_x = x + numerical_mask
        return stable_softmax(masked_x)

    xla_masked_softmax = tf.function(masked_softmax, jit_compile=True)
    xla_stable_softmax = tf.function(stable_softmax, jit_compile=True)
    x = tf.random.normal((batch_size, n_tokens))

    # Same outcome regardless of the boolean mask here.
    # `random.randint` is inclusive on both ends, so anywhere from 0 to all tokens may be masked.
    masked_tokens = random.randint(0, n_tokens)
    # Single-row mask; it is added to every row of `x` through broadcasting.
    boolean_mask = tf.convert_to_tensor([[1] * (n_tokens - masked_tokens) + [0] * masked_tokens], dtype=tf.int32)

    # We can randomly mask a random numerical input OUTSIDE XLA
    numerical_mask = (1.0 - tf.cast(boolean_mask, dtype=tf.float32)) * large_penalty
    masked_x = x + numerical_mask
    xla_out = xla_stable_softmax(masked_x)
    out = stable_softmax(masked_x)
    # `self.assertTrue` instead of a bare `assert`: it is not stripped under `python -O`
    # and matches the `self.assert*` calls used elsewhere in this test class.
    self.assertTrue(tf.experimental.numpy.allclose(xla_out, out))

    # The stable softmax has the same output as the original softmax
    unstable_out = tf.nn.softmax(masked_x)
    self.assertTrue(tf.experimental.numpy.allclose(unstable_out, out))

    # We can randomly mask a random numerical input INSIDE XLA
    xla_out = xla_masked_softmax(x, boolean_mask)
    out = masked_softmax(x, boolean_mask)
    self.assertTrue(tf.experimental.numpy.allclose(xla_out, out))
@require_tf
@is_staging_test