TF: T5 can now handle a padded past (i.e. XLA generation) (#17969)

* get the right slicing index for position_bias
This commit is contained in:
Joao Gante
2022-07-04 19:47:43 +01:00
committed by GitHub
parent e3139ad301
commit f0982682bd
2 changed files with 17 additions and 11 deletions

View File

@@ -590,21 +590,17 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
]
input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
# xla_generate = tf.function(model.generate, jit_compile=True)
xla_generate = tf.function(model.generate)
xla_generate = tf.function(model.generate, jit_compile=True)
# TODO (joao): there is something not quite right with XLA T5 -- as we increase `max_length` the two outputs
# drift appart, where the XLA version clearly degrades its quality. XLA-related variables look fine (they are
# being padded and filled in the right places). This also happens in other generation modes. Investigate.
output_ids = model.generate(input_ids, num_beams=2, max_length=9)
output_ids_xla = xla_generate(input_ids, num_beams=2, max_length=9)
output_ids = model.generate(input_ids, num_beams=2)
output_ids_xla = xla_generate(input_ids, num_beams=2)
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True)
expected_output_string = [
"Aujourd'hui est une belle journée.",
"J'ai quatre chats,",
"J'ai quatre chats, trois chiens, deux oiseaux et un cheval.",
]
self.assertListEqual(expected_output_string, output_strings)