TF: XLA repetition penalty (#16879)

This commit is contained in:
Joao Gante
2022-04-22 18:29:32 +01:00
committed by GitHub
parent ec81c11a18
commit 99c8226b12
2 changed files with 43 additions and 18 deletions

View File

@@ -101,7 +101,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
self.assertGreater(tf.math.reduce_max(probs[1, :]), tf.math.reduce_max(warped_prob_smooth[1, :]))
self.assertLess(tf.math.reduce_min(probs[1, :]), tf.math.reduce_min(warped_prob_smooth[1, :]))
def test_repetition_penalty_dist_process(self):
def _get_repetition_penalty_inputs(self):
vocab_size = 10
cur_len = 2
@@ -114,17 +114,31 @@ class TFLogitsProcessorTest(unittest.TestCase):
scores = tf.where(mask, -1 / vocab_size, scores)
mask = tf.cast(tf.constant([10 * [0], 5 * [0] + [1] + 4 * [0]]), tf.bool)
scores = tf.where(mask, 4 / vocab_size, scores)
return vocab_size, cur_len, input_ids, scores
rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
scores = rep_penalty_proc(input_ids, tf.identity(scores), cur_len)
# check that values were correctly changed
def _check_repetition_penalty_outputs(self, scores, vocab_size):
# check that values were correctly changed (negative scores for used tokens should increase, others
# should decrease)
self.assertAlmostEqual(scores[0, 0].numpy(), -(1 / vocab_size) * 2)
self.assertAlmostEqual(scores[0, 1].numpy(), (1 / vocab_size) / 2)
self.assertAlmostEqual(scores[0, 2].numpy(), (1 / vocab_size)) # unused tokens should see no change
self.assertAlmostEqual(scores[1, 0].numpy(), (1 / vocab_size) / 2)
self.assertAlmostEqual(scores[1, 5].numpy(), (4 / vocab_size) / 2)
self.assertAlmostEqual(scores[0, 2].numpy(), (1 / vocab_size)) # unused tokens should see no change
def test_repetition_penalty_dist_process(self):
vocab_size, cur_len, input_ids, scores = self._get_repetition_penalty_inputs()
rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
scores = rep_penalty_proc(input_ids, tf.identity(scores), cur_len)
self._check_repetition_penalty_outputs(scores, vocab_size)
def test_repetition_penalty_dist_process_xla(self):
vocab_size, cur_len, input_ids, scores = self._get_repetition_penalty_inputs()
rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
rep_penalty_proc = tf.function(rep_penalty_proc, jit_compile=True) # added line wrt non-XLA test
scores = rep_penalty_proc(input_ids, tf.identity(scores), cur_len)
self._check_repetition_penalty_outputs(scores, vocab_size)
def test_top_k_dist_warper(self):
input_ids = None