TF generate refactor - XLA sample (#16713)
This commit is contained in:
@@ -447,19 +447,6 @@ class TFGPT2ModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC
|
||||
|
||||
@require_tf
|
||||
class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
@slow
def test_lm_generate_distilgpt2(self):
    """Check that non-sampling generation with distilgpt2 yields the reference token ids."""
    lm = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
    # Prompt token ids for: The president
    prompt_ids = tf.convert_to_tensor([[464, 1893]], dtype=tf.int32)

    # Reference sequence, decoded:
    # The president of the United States, and the president of the United Kingdom, have been in the White
    # fmt: off
    reference_ids = [464, 1893, 286, 262, 1578, 1829, 11, 290, 262, 1893, 286, 262, 1578, 7526, 11, 423, 587, 287, 262, 2635]
    # fmt: on

    generated = lm.generate(prompt_ids, do_sample=False)
    first_sequence = generated[0].numpy().tolist()
    self.assertListEqual(first_sequence, reference_ids)
|
||||
|
||||
@slow
|
||||
def test_lm_generate_greedy_distilgpt2_batch_special(self):
|
||||
model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
|
||||
@@ -506,18 +493,18 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
"temperature": 1.5,
|
||||
"top_k": 500,
|
||||
"top_p": 0.9,
|
||||
"seed": [42, 0], # seed set -> deterministic sampling sequence -> deterministic generation
|
||||
}
|
||||
|
||||
# forces the generation to happen on CPU, to avoid GPU-related quirks
|
||||
with tf.device(":/CPU:0"):
|
||||
tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation
|
||||
output_ids = model.generate(input_ids, **generation_kwargs)
|
||||
|
||||
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
|
||||
expected_output_string = [
|
||||
"Today is a beautiful day and this makes finding holiday travel easier for you to do other project\nOh",
|
||||
"Yesterday was an enjoyable but especially great note though it certainly upset many Democrats who say",
|
||||
"Today is a beautiful day and we will make you feel very hot/terrific in all",
|
||||
"Yesterday was another solid success as news coverage became standard American domestic television hit.",
|
||||
]
|
||||
self.assertListEqual(output_strings, expected_output_string)
|
||||
|
||||
@@ -561,7 +548,7 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
||||
|
||||
@slow
|
||||
def test_lm_generate_gpt2_xla(self):
|
||||
def test_lm_generate_gpt2_xla_greedy(self):
|
||||
"""This test gives the exact same results as the non-xla test above"""
|
||||
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
|
||||
input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32) # The dog
|
||||
@@ -574,3 +561,16 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
|
||||
output_ids = xla_generate(input_ids, do_sample=False)
|
||||
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
||||
|
||||
@slow
def test_lm_generate_gpt2_xla_sample(self):
    """Check seeded sampling through a jit-compiled `generate` against fixed token ids."""
    gpt2 = TFGPT2LMHeadModel.from_pretrained("gpt2")
    # Prompt token ids for: The dog
    prompt_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32)

    # fmt: off
    reference_ids = [464, 3290, 550, 284, 307, 4376, 287, 281, 4044, 1363, 329, 734, 812, 878, 852, 4376, 757, 329, 2267, 0]
    # fmt: on

    # Wrap `generate` in a tf.function with jit compilation requested.
    compiled_generate = tf.function(gpt2.generate, jit_compile=True)

    # An explicit seed is passed so the sampled sequence can be compared with a fixed reference.
    sampled = compiled_generate(prompt_ids, do_sample=True, seed=[42, 0])
    first_sequence = sampled[0].numpy().tolist()
    self.assertListEqual(first_sequence, reference_ids)
|
||||
|
||||
@@ -524,6 +524,35 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
self.assertListEqual(expected_output_string, output_strings)
|
||||
|
||||
@slow
def test_sample_xla_generate_simple(self):
    """Run seeded sampling on t5-small with and without XLA compilation.

    Each run is compared against its own expected string; the two expected
    strings are only required to share their first 15 characters.
    """
    model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
    tokenizer = T5Tokenizer.from_pretrained("t5-small")

    sentence = "Translate English to German: Today is a beautiful day."
    input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
    # XLA reorders ops, which causes operations like FP matmul to have slightly different results, causing
    # divergences in generate -- especially with sampling.
    expected_output_string = ["Heute ist ein schöner Tag."]
    expected_output_string_xla = ["Heute ist ein schöne Tage."]
    # However, notice that the first tokens are the same, for the same seed.
    # Use a unittest assertion rather than a bare `assert`, which is stripped under `python -O`
    # and reports no diff on failure.
    self.assertEqual(expected_output_string[0][:15], expected_output_string_xla[0][:15])

    # forces the generation to happen on CPU, to avoid GPU-related quirks
    with tf.device(":/CPU:0"):
        # seed set -> deterministic sampling sequence -> deterministic generation
        output_ids = model.generate(input_ids, do_sample=True, seed=[42, 0])
    output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    self.assertListEqual(expected_output_string, output_strings)

    # forces the generation to happen on CPU, to avoid GPU-related quirks
    with tf.device(":/CPU:0"):
        xla_generate = tf.function(model.generate, jit_compile=True)
        # seed set -> deterministic sampling sequence -> deterministic generation
        output_ids_xla = xla_generate(input_ids, do_sample=True, seed=[42, 0])
    output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True)
    self.assertListEqual(expected_output_string_xla, output_strings_xla)
|
||||
|
||||
@slow
|
||||
def test_sample_generate(self):
|
||||
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
|
||||
@@ -540,16 +569,16 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
|
||||
"temperature": 0.8,
|
||||
"top_k": 500,
|
||||
"top_p": 0.9,
|
||||
"seed": [20, 0], # seed set -> deterministic sampling sequence -> deterministic generation
|
||||
}
|
||||
|
||||
# forces the generation to happen on CPU, to avoid GPU-related quirks
|
||||
with tf.device(":/CPU:0"):
|
||||
tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation
|
||||
output_ids = model.generate(input_ids, **generation_kwargs)
|
||||
|
||||
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
|
||||
expected_output_string = ["i love her I really love my heart", "die Transformatoren sind wirklich erstaunlich"]
|
||||
expected_output_string = ["- I really love my way of this.", "die Transformatoren sind wirklich erstaunlich"]
|
||||
|
||||
self.assertListEqual(expected_output_string, output_strings)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user