TF: T5 can now handle a padded past (i.e. XLA generation) (#17969)
* get the right slicing index for position_bias
This commit is contained in:
@@ -23,6 +23,7 @@ from typing import Optional, Tuple, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
from tensorflow.compiler.tf2xla.python.xla import dynamic_slice
|
||||||
|
|
||||||
from ...activations_tf import get_tf_activation
|
from ...activations_tf import get_tf_activation
|
||||||
from ...modeling_tf_outputs import (
|
from ...modeling_tf_outputs import (
|
||||||
@@ -384,10 +385,19 @@ class TFT5Attention(tf.keras.layers.Layer):
|
|||||||
else:
|
else:
|
||||||
position_bias = self.compute_bias(real_seq_length, key_length)
|
position_bias = self.compute_bias(real_seq_length, key_length)
|
||||||
|
|
||||||
# if key and values are already calculated
|
# if key and values are already calculated we want only the last query position bias
|
||||||
# we want only the last query position bias
|
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
|
if not self.has_relative_attention_bias:
|
||||||
position_bias = position_bias[:, :, -seq_length:, :]
|
position_bias = position_bias[:, :, -seq_length:, :]
|
||||||
|
else:
|
||||||
|
# we might have a padded past structure, in which case we want to fetch the position bias slice
|
||||||
|
# right after the most recently filled past index
|
||||||
|
most_recently_filled_past_index = tf.reduce_max(tf.where(past_key_value[0][0, 0, :, 0] != 0.0))
|
||||||
|
position_bias = dynamic_slice(
|
||||||
|
position_bias,
|
||||||
|
(0, 0, most_recently_filled_past_index + 1, 0),
|
||||||
|
(1, self.n_heads, seq_length, real_seq_length),
|
||||||
|
)
|
||||||
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
position_bias = tf.cast(position_bias, dtype=mask.dtype)
|
position_bias = tf.cast(position_bias, dtype=mask.dtype)
|
||||||
|
|||||||
@@ -590,21 +590,17 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
|
input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
|
||||||
|
|
||||||
# xla_generate = tf.function(model.generate, jit_compile=True)
|
xla_generate = tf.function(model.generate, jit_compile=True)
|
||||||
xla_generate = tf.function(model.generate)
|
|
||||||
|
|
||||||
# TODO (joao): there is something not quite right with XLA T5 -- as we increase `max_length` the two outputs
|
output_ids = model.generate(input_ids, num_beams=2)
|
||||||
# drift appart, where the XLA version clearly degrades its quality. XLA-related variables look fine (they are
|
output_ids_xla = xla_generate(input_ids, num_beams=2)
|
||||||
# being padded and filled in the right places). This also happens in other generation modes. Investigate.
|
|
||||||
output_ids = model.generate(input_ids, num_beams=2, max_length=9)
|
|
||||||
output_ids_xla = xla_generate(input_ids, num_beams=2, max_length=9)
|
|
||||||
|
|
||||||
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||||
output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True)
|
output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True)
|
||||||
|
|
||||||
expected_output_string = [
|
expected_output_string = [
|
||||||
"Aujourd'hui est une belle journée.",
|
"Aujourd'hui est une belle journée.",
|
||||||
"J'ai quatre chats,",
|
"J'ai quatre chats, trois chiens, deux oiseaux et un cheval.",
|
||||||
]
|
]
|
||||||
|
|
||||||
self.assertListEqual(expected_output_string, output_strings)
|
self.assertListEqual(expected_output_string, output_strings)
|
||||||
|
|||||||
Reference in New Issue
Block a user