TF: XLA stable softmax (#16892)
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -16,7 +16,7 @@
|
||||
import unittest
|
||||
|
||||
from transformers import GPT2Config, is_tf_available
|
||||
from transformers.testing_utils import get_gpu_count, require_tf, slow
|
||||
from transformers.testing_utils import require_tf, slow
|
||||
|
||||
from ..test_configuration_common import ConfigTester
|
||||
from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
@@ -536,8 +536,6 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
self.assertListEqual(output_strings, expected_output_string)
|
||||
|
||||
@slow
|
||||
@unittest.skipIf(not get_gpu_count(), "XLA not reliable on CPU")
|
||||
# TODO: remove the skip when the XLA CPU softmax issue gets sorted
|
||||
def test_lm_generate_gpt2_greedy_xla(self):
|
||||
# TODO (Joao): convert this to an example with a batch size>1 with different input lengths that works (and fix
|
||||
# the underlying problem)
|
||||
@@ -563,30 +561,33 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
self.assertListEqual(output_strings, expected_output_strings)
|
||||
|
||||
@slow
|
||||
@unittest.skipIf(not get_gpu_count(), "XLA not reliable on CPU")
|
||||
# TODO: remove the skip when the XLA CPU softmax issue gets sorted
|
||||
def test_lm_generate_gpt2_sample_xla(self):
|
||||
# NOTE: due to the small numerical differences that are natural when we compile to XLA, sampling the same
|
||||
# output out of the same seed is far from guaranteed. We can, however, confirm that the results are sensible
|
||||
# and that we can seed both versions.
|
||||
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
tokenizer.padding_side = "left"
|
||||
# forces the generation to happen on CPU, to avoid GPU-related quirks
|
||||
with tf.device(":/CPU:0"):
|
||||
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
|
||||
sentence = ["The dog"]
|
||||
expected_output_string = [
|
||||
"The dog must be well educated to do anything. If anything, this must be her best friend"
|
||||
]
|
||||
expected_output_string_xla = ["The dog has been named in connection with the murder of a 20-year-old man in!"]
|
||||
input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
tokenizer.padding_side = "left"
|
||||
|
||||
output_ids = model.generate(input_ids, do_sample=True, seed=[7, 0])
|
||||
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
self.assertListEqual(output_strings, expected_output_string)
|
||||
sentence = ["The dog"]
|
||||
expected_output_string = [
|
||||
"The dog owner asked why did our vet decide there needed to be extra ventilation inside because most puppies"
|
||||
]
|
||||
expected_output_string_xla = [
|
||||
"The dog has been named in connection with the murder of a 20-year-old man in!"
|
||||
]
|
||||
input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
|
||||
|
||||
xla_generate = tf.function(model.generate, jit_compile=True)
|
||||
output_ids = xla_generate(input_ids, do_sample=True, seed=[7, 0])
|
||||
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
self.assertListEqual(output_strings, expected_output_string_xla)
|
||||
output_ids = model.generate(input_ids, do_sample=True, seed=[7, 0])
|
||||
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
self.assertListEqual(output_strings, expected_output_string)
|
||||
|
||||
xla_generate = tf.function(model.generate, jit_compile=True)
|
||||
output_ids = xla_generate(input_ids, do_sample=True, seed=[7, 0])
|
||||
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
self.assertListEqual(output_strings, expected_output_string_xla)
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
import unittest
|
||||
|
||||
from transformers import T5Config, is_tf_available
|
||||
from transformers.testing_utils import get_gpu_count, require_sentencepiece, require_tf, require_tokenizers, slow
|
||||
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ..test_configuration_common import ConfigTester
|
||||
@@ -481,8 +481,6 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
@require_tokenizers
|
||||
class TFT5GenerationIntegrationTests(unittest.TestCase):
|
||||
@slow
|
||||
@unittest.skipIf(not get_gpu_count(), "XLA not reliable on CPU")
|
||||
# TODO: remove the skip when the XLA CPU softmax issue gets sorted
|
||||
def test_greedy_xla_generate_simple(self):
|
||||
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
|
||||
tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
||||
@@ -534,30 +532,31 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
|
||||
self.assertListEqual(expected_output_string, output_strings)
|
||||
|
||||
@slow
|
||||
@unittest.skipIf(not get_gpu_count(), "XLA not reliable on CPU")
|
||||
# TODO: remove the skip when the XLA CPU softmax issue gets sorted
|
||||
def test_sample_xla_generate_simple(self):
|
||||
# NOTE: due to the small numerical differences that are natural when we compile to XLA, sampling the same
|
||||
# output out of the same seed is far from guaranteed (unlike this example). We can, however, confirm that the
|
||||
# results are sensible and that we can seed both versions.
|
||||
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
|
||||
tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
||||
# output out of the same seed is far from guaranteed. We can, however, confirm that the results are sensible
|
||||
# and that we can seed both versions.
|
||||
|
||||
sentence = "Translate English to German: I have two bananas"
|
||||
input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
|
||||
expected_output_string = ["Ich habe 2 Bananen"]
|
||||
expected_output_string_xla = ["Ich habe 2 Bananen"]
|
||||
# forces the generation to happen on CPU, to avoid GPU-related quirks
|
||||
with tf.device(":/CPU:0"):
|
||||
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
|
||||
tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
||||
|
||||
# seed set -> deterministic sampling sequence -> deterministic generation
|
||||
output_ids = model.generate(input_ids, do_sample=True, seed=[42, 0])
|
||||
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
self.assertListEqual(expected_output_string, output_strings)
|
||||
sentence = "Translate English to German: I have two bananas"
|
||||
input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
|
||||
expected_output_string = ["Ich habe zwei Bananen"]
|
||||
expected_output_string_xla = ["Ich habe 2 Bananen"]
|
||||
|
||||
xla_generate = tf.function(model.generate, jit_compile=True)
|
||||
# seed set -> deterministic sampling sequence -> deterministic generation
|
||||
output_ids_xla = xla_generate(input_ids, do_sample=True, seed=[42, 0])
|
||||
output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True)
|
||||
self.assertListEqual(expected_output_string_xla, output_strings_xla)
|
||||
# seed set -> deterministic sampling sequence -> deterministic generation
|
||||
output_ids = model.generate(input_ids, do_sample=True, seed=[42, 0])
|
||||
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
self.assertListEqual(expected_output_string, output_strings)
|
||||
|
||||
xla_generate = tf.function(model.generate, jit_compile=True)
|
||||
# seed set -> deterministic sampling sequence -> deterministic generation
|
||||
output_ids_xla = xla_generate(input_ids, do_sample=True, seed=[42, 0])
|
||||
output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True)
|
||||
self.assertListEqual(expected_output_string_xla, output_strings_xla)
|
||||
|
||||
@slow
|
||||
def test_sample_generate(self):
|
||||
|
||||
@@ -84,6 +84,7 @@ if is_tf_available():
|
||||
TFSampleEncoderDecoderOutput,
|
||||
)
|
||||
from transformers.modeling_tf_utils import unpack_inputs
|
||||
from transformers.tf_utils import stable_softmax
|
||||
|
||||
if _tf_gpu_memory_limit is not None:
|
||||
gpus = tf.config.list_physical_devices("GPU")
|
||||
@@ -1709,6 +1710,41 @@ class UtilsFunctionsTest(unittest.TestCase):
|
||||
self.assertFalse(output[3])
|
||||
self.assertFalse(output[4])
|
||||
|
||||
# Tests whether the stable softmax is stable on CPU, with and without XLA
|
||||
def test_xla_stable_softmax(self):
|
||||
large_penalty = -1e9
|
||||
n_tokens = 10
|
||||
batch_size = 8
|
||||
|
||||
def masked_softmax(x, boolean_mask):
|
||||
numerical_mask = (1.0 - tf.cast(boolean_mask, dtype=tf.float32)) * large_penalty
|
||||
masked_x = x + numerical_mask
|
||||
return stable_softmax(masked_x)
|
||||
|
||||
xla_masked_softmax = tf.function(masked_softmax, jit_compile=True)
|
||||
xla_stable_softmax = tf.function(stable_softmax, jit_compile=True)
|
||||
x = tf.random.normal((batch_size, n_tokens))
|
||||
|
||||
# Same outcome regardless of the boolean mask here
|
||||
masked_tokens = random.randint(0, n_tokens)
|
||||
boolean_mask = tf.convert_to_tensor([[1] * (n_tokens - masked_tokens) + [0] * masked_tokens], dtype=tf.int32)
|
||||
|
||||
# We can randomly mask a random numerical input OUTSIDE XLA
|
||||
numerical_mask = (1.0 - tf.cast(boolean_mask, dtype=tf.float32)) * large_penalty
|
||||
masked_x = x + numerical_mask
|
||||
xla_out = xla_stable_softmax(masked_x)
|
||||
out = stable_softmax(masked_x)
|
||||
assert tf.experimental.numpy.allclose(xla_out, out)
|
||||
|
||||
# The stable softmax has the same output as the original softmax
|
||||
unstable_out = tf.nn.softmax(masked_x)
|
||||
assert tf.experimental.numpy.allclose(unstable_out, out)
|
||||
|
||||
# We can randomly mask a random numerical input INSIDE XLA
|
||||
xla_out = xla_masked_softmax(x, boolean_mask)
|
||||
out = masked_softmax(x, boolean_mask)
|
||||
assert tf.experimental.numpy.allclose(xla_out, out)
|
||||
|
||||
|
||||
@require_tf
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user