TF XLA greedy generation (#15786)

* First attempt at TF XLA generation

* Fix comments

* Update XLA greedy generate with direct XLA calls

* Support attention mask, prepare_inputs_for_generation no longer hardcoded for greedy

* Handle position_ids correctly

* make xla generate work for non xla case

* force using xla generate

* refactor

* more fixes

* finish cleaning

* finish

* finish

* clean gpt2 tests

* add gpt2 tests

* correct more cases

* up

* finish

* finish

* more fixes

* flake 8 stuff

* final rag fix

* Update src/transformers/models/rag/modeling_tf_rag.py

* finish t5 as well

* finish

* Update src/transformers/generation_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Matt
2022-03-15 13:19:20 +00:00
committed by GitHub
parent e5bc438cc8
commit cd4c5c9060
8 changed files with 370 additions and 93 deletions

View File

@@ -660,29 +660,16 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
else:
model.gradient_checkpointing_disable()
model.to(torch_device)
input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog
# The dog
input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device)
# The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
# fmt: off
expected_output_ids = [
464,
3290,
373,
1043,
287,
257,
2214,
1474,
262,
16246,
286,
2688,
290,
2688,
27262,
13,
198,
198,
464,
3290,
] # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
464, 3290, 373, 1043, 287, 257, 2214, 1474, 262, 16246, 286, 2688, 290, 2688, 27262, 13, 198, 198, 464, 3290,
]
# fmt: on
output_ids = model.generate(input_ids, do_sample=False)
if verify_outputs:
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)

View File

@@ -294,6 +294,21 @@ class TFGPT2ModelTester:
result = model(inputs)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_gpt2_xla_generate(self, config, input_ids, *args):
config.eos_token_id = None
config.max_length = 10
model = TFGPT2LMHeadModel(config=config)
# make sure there are no pad tokens in prompt
input_ids = tf.where(input_ids != config.pad_token_id, input_ids, config.pad_token_id - 1)
generated = model.generate(input_ids)
generate_xla = tf.function(model.generate, jit_compile=True)
generated_xla = generate_xla(input_ids)
self.parent.assertListEqual(generated.numpy().tolist(), generated_xla.numpy().tolist())
def create_and_check_gpt2_double_head(
self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args
):
@@ -393,6 +408,10 @@ class TFGPT2ModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_gpt2_lm_head(*config_and_inputs)
def test_gpt2_xla_generate(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_gpt2_xla_generate(*config_and_inputs)
def test_gpt2_double_head(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_gpt2_double_head(*config_and_inputs)
@@ -513,3 +532,18 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
# fmt: on
output_ids = model.generate(input_ids, do_sample=False)
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
@slow
def test_lm_generate_gpt2_xla(self):
"""This test gives the exact same results as the non-xla test above"""
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32) # The dog
# The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
# fmt: off
expected_output_ids = [464, 3290, 373, 1043, 287, 257, 2214, 1474, 262, 16246, 286, 2688, 290, 2688, 27262, 13, 198, 198, 464, 3290]
# fmt: on
xla_generate = tf.function(model.generate, jit_compile=True)
output_ids = xla_generate(input_ids, do_sample=False)
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)

View File

@@ -227,6 +227,23 @@ class TFT5ModelTester:
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
def create_and_check_t5_xla_generate(self, config, input_ids, *args):
config.eos_token_id = None
config.max_length = 10
config.do_sample = False
config.num_beams = 1
model = TFT5ForConditionalGeneration(config=config)
# make sure there are no pad tokens in prompt
input_ids = tf.where(input_ids != config.pad_token_id, input_ids, config.pad_token_id + 5)
generated = model.generate(input_ids)
generate_xla = tf.function(model.generate, jit_compile=True)
generated_xla = generate_xla(input_ids)
self.parent.assertListEqual(generated.numpy().tolist(), generated_xla.numpy().tolist())
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, input_mask, token_labels) = config_and_inputs
@@ -280,6 +297,10 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_t5_decoder_model_past_large_inputs(*config_and_inputs)
def test_t5_model_xla_generate(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_t5_xla_generate(*config_and_inputs)
def test_model_common_attributes(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -454,6 +475,27 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
@require_sentencepiece
@require_tokenizers
class TFT5GenerationIntegrationTests(unittest.TestCase):
@slow
def test_greedy_xla_generate_simple(self):
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")
sentence = "Translate English to German: Today is a beautiful day."
input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
xla_generate = tf.function(model.generate, jit_compile=True)
output_ids = model.generate(input_ids)
output_ids_xla = xla_generate(input_ids)
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True)
expected_output_string = ["Heute ist ein schöner Tag."]
self.assertListEqual(expected_output_string, output_strings)
self.assertListEqual(expected_output_string, output_strings_xla)
@slow
def test_greedy_generate(self):
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")