TF generate refactor - XLA sample (#16713)

This commit is contained in:
Joao Gante
2022-04-18 10:58:24 +01:00
committed by GitHub
parent 02de7a8e7f
commit b4ddd2677c
3 changed files with 187 additions and 88 deletions

View File

@@ -447,19 +447,6 @@ class TFGPT2ModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC
@require_tf
class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
@slow
def test_lm_generate_distilgpt2(self):
model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
input_ids = tf.convert_to_tensor([[464, 1893]], dtype=tf.int32) # The president
# The president of the United States, and the president of the United Kingdom, have been in the White
# fmt: off
expected_output_ids = [464, 1893, 286, 262, 1578, 1829, 11, 290, 262, 1893, 286, 262, 1578, 7526, 11, 423, 587, 287, 262, 2635]
# fmt: on
output_ids = model.generate(input_ids, do_sample=False)
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
@slow
def test_lm_generate_greedy_distilgpt2_batch_special(self):
model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
@@ -506,18 +493,18 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
"temperature": 1.5,
"top_k": 500,
"top_p": 0.9,
"seed": [42, 0], # seed set -> deterministic sampling sequence -> deterministic generation
}
# forces the generation to happen on CPU, to avoid GPU-related quirks
with tf.device(":/CPU:0"):
tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation
output_ids = model.generate(input_ids, **generation_kwargs)
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
expected_output_string = [
"Today is a beautiful day and this makes finding holiday travel easier for you to do other project\nOh",
"Yesterday was an enjoyable but especially great note though it certainly upset many Democrats who say",
"Today is a beautiful day and we will make you feel very hot/terrific in all",
"Yesterday was another solid success as news coverage became standard American domestic television hit.",
]
self.assertListEqual(output_strings, expected_output_string)
@@ -561,7 +548,7 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
@slow
def test_lm_generate_gpt2_xla(self):
def test_lm_generate_gpt2_xla_greedy(self):
"""This test gives the exact same results as the non-xla test above"""
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32) # The dog
@@ -574,3 +561,16 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
output_ids = xla_generate(input_ids, do_sample=False)
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
@slow
def test_lm_generate_gpt2_xla_sample(self):
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32) # The dog
# fmt: off
expected_output_ids = [464, 3290, 550, 284, 307, 4376, 287, 281, 4044, 1363, 329, 734, 812, 878, 852, 4376, 757, 329, 2267, 0]
# fmt: on
xla_generate = tf.function(model.generate, jit_compile=True)
output_ids = xla_generate(input_ids, do_sample=True, seed=[42, 0])
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)

View File

@@ -524,6 +524,35 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
self.assertListEqual(expected_output_string, output_strings)
@slow
def test_sample_xla_generate_simple(self):
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")
sentence = "Translate English to German: Today is a beautiful day."
input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
# XLA reorder ops, which causes operations like FP matmul to have slightly different results, causing
# divergences in generate -- especially with sampling.
expected_output_string = ["Heute ist ein schöner Tag."]
expected_output_string_xla = ["Heute ist ein schöne Tage."]
# However, notice that the first tokens are the same, for the same seed
assert expected_output_string[0][:15] == expected_output_string_xla[0][:15]
# forces the generation to happen on CPU, to avoid GPU-related quirks
with tf.device(":/CPU:0"):
# seed set -> deterministic sampling sequence -> deterministic generation
output_ids = model.generate(input_ids, do_sample=True, seed=[42, 0])
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertListEqual(expected_output_string, output_strings)
# forces the generation to happen on CPU, to avoid GPU-related quirks
with tf.device(":/CPU:0"):
xla_generate = tf.function(model.generate, jit_compile=True)
# seed set -> deterministic sampling sequence -> deterministic generation
output_ids_xla = xla_generate(input_ids, do_sample=True, seed=[42, 0])
output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True)
self.assertListEqual(expected_output_string_xla, output_strings_xla)
@slow
def test_sample_generate(self):
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
@@ -540,16 +569,16 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
"temperature": 0.8,
"top_k": 500,
"top_p": 0.9,
"seed": [20, 0], # seed set -> deterministic sampling sequence -> deterministic generation
}
# forces the generation to happen on CPU, to avoid GPU-related quirks
with tf.device(":/CPU:0"):
tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation
output_ids = model.generate(input_ids, **generation_kwargs)
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
expected_output_string = ["i love her I really love my heart", "die Transformatoren sind wirklich erstaunlich"]
expected_output_string = ["- I really love my way of this.", "die Transformatoren sind wirklich erstaunlich"]
self.assertListEqual(expected_output_string, output_strings)