TF: XLA repetition penalty (#16879)
This commit is contained in:
@@ -241,18 +241,29 @@ class TFRepetitionPenaltyLogitsProcessor(TFLogitsProcessor):
|
|||||||
|
|
||||||
self.penalty = penalty
|
self.penalty = penalty
|
||||||
|
|
||||||
def _create_score_penalties(self, input_ids, logits):
|
def _create_score_penalties(self, input_ids: tf.Tensor, logits: tf.Tensor) -> tf.Tensor:
|
||||||
# create logit penalties for already seen input_ids
|
# We want to populate the penalties in the positions of `input_ids`. Since XLA can't handle shapes unknown
|
||||||
token_penalties = np.ones(logits.shape)
|
# before runtime, `tf.unique` can't be used. Therefore, we may have redundant updates, when a given row has
|
||||||
prev_input_ids = [np.unique(input_id) for input_id in input_ids.numpy()]
|
# the same token multiple times (harmless: duplicates of a token in a row all write the same penalty value).
|
||||||
for i, prev_input_id in enumerate(prev_input_ids):
|
|
||||||
logit_penalized = logits[i].numpy()[prev_input_id]
|
# Gathers the penalties to apply
|
||||||
logit_penalties = np.zeros(logit_penalized.shape)
|
logit_penalties = tf.gather(logits, input_ids, axis=1, batch_dims=1)
|
||||||
# if previous logit score is < 0 then multiply repetition penalty else divide
|
logit_penalties = tf.where(logit_penalties > 0, 1 / self.penalty, logit_penalties)
|
||||||
logit_penalties[logit_penalized < 0] = self.penalty
|
logit_penalties = tf.where(logit_penalties < 0, self.penalty, logit_penalties)
|
||||||
logit_penalties[logit_penalized > 0] = 1 / self.penalty
|
|
||||||
np.put(token_penalties[i], prev_input_id, logit_penalties)
|
# Scatters the penalties
|
||||||
return tf.convert_to_tensor(token_penalties, dtype=tf.float32)
|
token_penalties = tf.ones(logits.shape)
|
||||||
|
indexable_prev_input_ids = tf.concat(
|
||||||
|
(
|
||||||
|
tf.expand_dims(tf.repeat(tf.range(input_ids.shape[0]), input_ids.shape[1]), axis=-1),
|
||||||
|
tf.expand_dims(tf.reshape(input_ids, [-1]), axis=-1),
|
||||||
|
),
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
|
token_penalties = tf.tensor_scatter_nd_update(
|
||||||
|
token_penalties, indices=indexable_prev_input_ids, updates=tf.reshape(logit_penalties, [-1])
|
||||||
|
)
|
||||||
|
return token_penalties
|
||||||
|
|
||||||
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
||||||
score_penalties = self._create_score_penalties(input_ids[:, :cur_len], scores)
|
score_penalties = self._create_score_penalties(input_ids[:, :cur_len], scores)
|
||||||
|
|||||||
@@ -101,7 +101,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
|||||||
self.assertGreater(tf.math.reduce_max(probs[1, :]), tf.math.reduce_max(warped_prob_smooth[1, :]))
|
self.assertGreater(tf.math.reduce_max(probs[1, :]), tf.math.reduce_max(warped_prob_smooth[1, :]))
|
||||||
self.assertLess(tf.math.reduce_min(probs[1, :]), tf.math.reduce_min(warped_prob_smooth[1, :]))
|
self.assertLess(tf.math.reduce_min(probs[1, :]), tf.math.reduce_min(warped_prob_smooth[1, :]))
|
||||||
|
|
||||||
def test_repetition_penalty_dist_process(self):
|
def _get_repetition_penalty_inputs(self):
|
||||||
vocab_size = 10
|
vocab_size = 10
|
||||||
cur_len = 2
|
cur_len = 2
|
||||||
|
|
||||||
@@ -114,17 +114,31 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
|||||||
scores = tf.where(mask, -1 / vocab_size, scores)
|
scores = tf.where(mask, -1 / vocab_size, scores)
|
||||||
mask = tf.cast(tf.constant([10 * [0], 5 * [0] + [1] + 4 * [0]]), tf.bool)
|
mask = tf.cast(tf.constant([10 * [0], 5 * [0] + [1] + 4 * [0]]), tf.bool)
|
||||||
scores = tf.where(mask, 4 / vocab_size, scores)
|
scores = tf.where(mask, 4 / vocab_size, scores)
|
||||||
|
return vocab_size, cur_len, input_ids, scores
|
||||||
|
|
||||||
rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
|
def _check_repetition_penalty_outputs(self, scores, vocab_size):
|
||||||
|
# check that values were correctly changed (negative scores of used tokens are multiplied by the penalty,
|
||||||
scores = rep_penalty_proc(input_ids, tf.identity(scores), cur_len)
|
# positive scores of used tokens are divided by it)
|
||||||
|
|
||||||
# check that values were correctly changed
|
|
||||||
self.assertAlmostEqual(scores[0, 0].numpy(), -(1 / vocab_size) * 2)
|
self.assertAlmostEqual(scores[0, 0].numpy(), -(1 / vocab_size) * 2)
|
||||||
self.assertAlmostEqual(scores[0, 1].numpy(), (1 / vocab_size) / 2)
|
self.assertAlmostEqual(scores[0, 1].numpy(), (1 / vocab_size) / 2)
|
||||||
|
self.assertAlmostEqual(scores[0, 2].numpy(), (1 / vocab_size)) # unused tokens should see no change
|
||||||
|
|
||||||
self.assertAlmostEqual(scores[1, 0].numpy(), (1 / vocab_size) / 2)
|
self.assertAlmostEqual(scores[1, 0].numpy(), (1 / vocab_size) / 2)
|
||||||
self.assertAlmostEqual(scores[1, 5].numpy(), (4 / vocab_size) / 2)
|
self.assertAlmostEqual(scores[1, 5].numpy(), (4 / vocab_size) / 2)
|
||||||
|
self.assertAlmostEqual(scores[0, 2].numpy(), (1 / vocab_size)) # unused tokens should see no change
|
||||||
|
|
||||||
|
def test_repetition_penalty_dist_process(self):
|
||||||
|
vocab_size, cur_len, input_ids, scores = self._get_repetition_penalty_inputs()
|
||||||
|
rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
|
||||||
|
scores = rep_penalty_proc(input_ids, tf.identity(scores), cur_len)
|
||||||
|
self._check_repetition_penalty_outputs(scores, vocab_size)
|
||||||
|
|
||||||
|
def test_repetition_penalty_dist_process_xla(self):
|
||||||
|
vocab_size, cur_len, input_ids, scores = self._get_repetition_penalty_inputs()
|
||||||
|
rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
|
||||||
|
rep_penalty_proc = tf.function(rep_penalty_proc, jit_compile=True) # added line wrt non-XLA test
|
||||||
|
scores = rep_penalty_proc(input_ids, tf.identity(scores), cur_len)
|
||||||
|
self._check_repetition_penalty_outputs(scores, vocab_size)
|
||||||
|
|
||||||
def test_top_k_dist_warper(self):
|
def test_top_k_dist_warper(self):
|
||||||
input_ids = None
|
input_ids = None
|
||||||
|
|||||||
Reference in New Issue
Block a user