From 99c8226b12b7e11764e7fb22b2aa431c49a58a98 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Fri, 22 Apr 2022 18:29:32 +0100
Subject: [PATCH] TF: XLA repetition penalty (#16879)

Make `TFRepetitionPenaltyLogitsProcessor._create_score_penalties` compilable
with XLA by replacing the NumPy-based implementation with TF ops only:

* `tf.gather` (with `batch_dims=1`) reads the logits of the tokens present in
  `input_ids`, and `tf.where` turns them into penalties (`1 / penalty` for
  positive logits, `penalty` for negative ones).
* `tf.tensor_scatter_nd_update` writes those penalties into a tensor of ones.
  `tf.unique` is avoided because XLA can't handle shapes unknown before
  runtime, so a token repeated within a row yields redundant updates.

Tests: the repetition penalty test is split into input/check helpers shared by
the existing test and a new `tf.function(..., jit_compile=True)` test. The
check helper also asserts that unused tokens are left unchanged in both batch
rows.
---
 .../generation_tf_logits_process.py           | 35 ++++++++++++-------
 .../test_generation_tf_logits_process.py      | 26 ++++++++++----
 2 files changed, 43 insertions(+), 18 deletions(-)

diff --git a/src/transformers/generation_tf_logits_process.py b/src/transformers/generation_tf_logits_process.py
index 48df31a3ff..eefd1f0ace 100644
--- a/src/transformers/generation_tf_logits_process.py
+++ b/src/transformers/generation_tf_logits_process.py
@@ -241,18 +241,29 @@ class TFRepetitionPenaltyLogitsProcessor(TFLogitsProcessor):
 
         self.penalty = penalty
 
-    def _create_score_penalties(self, input_ids, logits):
-        # create logit penalties for already seen input_ids
-        token_penalties = np.ones(logits.shape)
-        prev_input_ids = [np.unique(input_id) for input_id in input_ids.numpy()]
-        for i, prev_input_id in enumerate(prev_input_ids):
-            logit_penalized = logits[i].numpy()[prev_input_id]
-            logit_penalties = np.zeros(logit_penalized.shape)
-            # if previous logit score is < 0 then multiply repetition penalty else divide
-            logit_penalties[logit_penalized < 0] = self.penalty
-            logit_penalties[logit_penalized > 0] = 1 / self.penalty
-            np.put(token_penalties[i], prev_input_id, logit_penalties)
-        return tf.convert_to_tensor(token_penalties, dtype=tf.float32)
+    def _create_score_penalties(self, input_ids: tf.Tensor, logits: tf.Tensor) -> tf.Tensor:
+        # We want to populate the penalties in the positions of `input_ids`. Since XLA can't handle shapes unknown
+        # before runtime, `tf.unique` can't be used. Therefore, we may have redundant updates, when a given row has
+        # the same token multiple times.
+
+        # Gathers the penalties to apply
+        logit_penalties = tf.gather(logits, input_ids, axis=1, batch_dims=1)
+        logit_penalties = tf.where(logit_penalties > 0, 1 / self.penalty, logit_penalties)
+        logit_penalties = tf.where(logit_penalties < 0, self.penalty, logit_penalties)
+
+        # Scatters the penalties
+        token_penalties = tf.ones(logits.shape)
+        indexable_prev_input_ids = tf.concat(
+            (
+                tf.expand_dims(tf.repeat(tf.range(input_ids.shape[0]), input_ids.shape[1]), axis=-1),
+                tf.expand_dims(tf.reshape(input_ids, [-1]), axis=-1),
+            ),
+            axis=1,
+        )
+        token_penalties = tf.tensor_scatter_nd_update(
+            token_penalties, indices=indexable_prev_input_ids, updates=tf.reshape(logit_penalties, [-1])
+        )
+        return token_penalties
 
     def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
         score_penalties = self._create_score_penalties(input_ids[:, :cur_len], scores)
diff --git a/tests/generation/test_generation_tf_logits_process.py b/tests/generation/test_generation_tf_logits_process.py
index 60fa3352e5..913b26cb64 100644
--- a/tests/generation/test_generation_tf_logits_process.py
+++ b/tests/generation/test_generation_tf_logits_process.py
@@ -101,7 +101,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
         self.assertGreater(tf.math.reduce_max(probs[1, :]), tf.math.reduce_max(warped_prob_smooth[1, :]))
         self.assertLess(tf.math.reduce_min(probs[1, :]), tf.math.reduce_min(warped_prob_smooth[1, :]))
 
-    def test_repetition_penalty_dist_process(self):
+    def _get_repetition_penalty_inputs(self):
         vocab_size = 10
         cur_len = 2
 
@@ -114,17 +114,31 @@ class TFLogitsProcessorTest(unittest.TestCase):
         scores = tf.where(mask, -1 / vocab_size, scores)
         mask = tf.cast(tf.constant([10 * [0], 5 * [0] + [1] + 4 * [0]]), tf.bool)
         scores = tf.where(mask, 4 / vocab_size, scores)
+        return vocab_size, cur_len, input_ids, scores
 
-        rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
-
-        scores = rep_penalty_proc(input_ids, tf.identity(scores), cur_len)
-
-        # check that values were correctly changed
+    def _check_repetition_penalty_outputs(self, scores, vocab_size):
+        # check that values were correctly changed (negative scores of used tokens are multiplied by the penalty,
+        # positive ones are divided by it)
         self.assertAlmostEqual(scores[0, 0].numpy(), -(1 / vocab_size) * 2)
         self.assertAlmostEqual(scores[0, 1].numpy(), (1 / vocab_size) / 2)
+        self.assertAlmostEqual(scores[0, 2].numpy(), (1 / vocab_size))  # unused tokens should see no change
 
         self.assertAlmostEqual(scores[1, 0].numpy(), (1 / vocab_size) / 2)
         self.assertAlmostEqual(scores[1, 5].numpy(), (4 / vocab_size) / 2)
+        self.assertAlmostEqual(scores[1, 2].numpy(), (1 / vocab_size))  # unused tokens should see no change
+
+    def test_repetition_penalty_dist_process(self):
+        vocab_size, cur_len, input_ids, scores = self._get_repetition_penalty_inputs()
+        rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
+        scores = rep_penalty_proc(input_ids, tf.identity(scores), cur_len)
+        self._check_repetition_penalty_outputs(scores, vocab_size)
+
+    def test_repetition_penalty_dist_process_xla(self):
+        vocab_size, cur_len, input_ids, scores = self._get_repetition_penalty_inputs()
+        rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
+        rep_penalty_proc = tf.function(rep_penalty_proc, jit_compile=True)  # added line wrt non-XLA test
+        scores = rep_penalty_proc(input_ids, tf.identity(scores), cur_len)
+        self._check_repetition_penalty_outputs(scores, vocab_size)
 
     def test_top_k_dist_warper(self):
         input_ids = None