TF XLA greedy generation (#15786)
* First attempt at TF XLA generation * Fix comments * Update XLA greedy generate with direct XLA calls * Support attention mask, prepare_inputs_for_generation no longer hardcoded for greedy * Handle position_ids correctly * make xla generate work for non xla case * force using xla generate * refactor * more fixes * finish cleaning * finish * finish * clean gpt2 tests * add gpt2 tests * correct more cases * up * finish * finish * more fixes * flake 8 stuff * final rag fix * Update src/transformers/models/rag/modeling_tf_rag.py * finish t5 as well * finish * Update src/transformers/generation_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -660,29 +660,16 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
else:
|
||||
model.gradient_checkpointing_disable()
|
||||
model.to(torch_device)
|
||||
input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog
|
||||
|
||||
# The dog
|
||||
input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device)
|
||||
|
||||
# The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
|
||||
# fmt: off
|
||||
expected_output_ids = [
|
||||
464,
|
||||
3290,
|
||||
373,
|
||||
1043,
|
||||
287,
|
||||
257,
|
||||
2214,
|
||||
1474,
|
||||
262,
|
||||
16246,
|
||||
286,
|
||||
2688,
|
||||
290,
|
||||
2688,
|
||||
27262,
|
||||
13,
|
||||
198,
|
||||
198,
|
||||
464,
|
||||
3290,
|
||||
] # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
|
||||
464, 3290, 373, 1043, 287, 257, 2214, 1474, 262, 16246, 286, 2688, 290, 2688, 27262, 13, 198, 198, 464, 3290,
|
||||
]
|
||||
# fmt: on
|
||||
output_ids = model.generate(input_ids, do_sample=False)
|
||||
if verify_outputs:
|
||||
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||
|
||||
@@ -294,6 +294,21 @@ class TFGPT2ModelTester:
|
||||
result = model(inputs)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_gpt2_xla_generate(self, config, input_ids, *args):
|
||||
config.eos_token_id = None
|
||||
config.max_length = 10
|
||||
model = TFGPT2LMHeadModel(config=config)
|
||||
|
||||
# make sure there are no pad tokens in prompt
|
||||
input_ids = tf.where(input_ids != config.pad_token_id, input_ids, config.pad_token_id - 1)
|
||||
|
||||
generated = model.generate(input_ids)
|
||||
|
||||
generate_xla = tf.function(model.generate, jit_compile=True)
|
||||
generated_xla = generate_xla(input_ids)
|
||||
|
||||
self.parent.assertListEqual(generated.numpy().tolist(), generated_xla.numpy().tolist())
|
||||
|
||||
def create_and_check_gpt2_double_head(
|
||||
self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args
|
||||
):
|
||||
@@ -393,6 +408,10 @@ class TFGPT2ModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_gpt2_lm_head(*config_and_inputs)
|
||||
|
||||
def test_gpt2_xla_generate(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_gpt2_xla_generate(*config_and_inputs)
|
||||
|
||||
def test_gpt2_double_head(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_gpt2_double_head(*config_and_inputs)
|
||||
@@ -513,3 +532,18 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
# fmt: on
|
||||
output_ids = model.generate(input_ids, do_sample=False)
|
||||
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
||||
|
||||
@slow
|
||||
def test_lm_generate_gpt2_xla(self):
|
||||
"""This test gives the exact same results as the non-xla test above"""
|
||||
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
|
||||
input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32) # The dog
|
||||
|
||||
# The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
|
||||
# fmt: off
|
||||
expected_output_ids = [464, 3290, 373, 1043, 287, 257, 2214, 1474, 262, 16246, 286, 2688, 290, 2688, 27262, 13, 198, 198, 464, 3290]
|
||||
# fmt: on
|
||||
xla_generate = tf.function(model.generate, jit_compile=True)
|
||||
|
||||
output_ids = xla_generate(input_ids, do_sample=False)
|
||||
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
||||
|
||||
Reference in New Issue
Block a user