TF: XLA repetition penalty (#16879)
This commit is contained in:
@@ -101,7 +101,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
||||
self.assertGreater(tf.math.reduce_max(probs[1, :]), tf.math.reduce_max(warped_prob_smooth[1, :]))
|
||||
self.assertLess(tf.math.reduce_min(probs[1, :]), tf.math.reduce_min(warped_prob_smooth[1, :]))
|
||||
|
||||
def test_repetition_penalty_dist_process(self):
|
||||
def _get_repetition_penalty_inputs(self):
|
||||
vocab_size = 10
|
||||
cur_len = 2
|
||||
|
||||
@@ -114,17 +114,31 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
||||
scores = tf.where(mask, -1 / vocab_size, scores)
|
||||
mask = tf.cast(tf.constant([10 * [0], 5 * [0] + [1] + 4 * [0]]), tf.bool)
|
||||
scores = tf.where(mask, 4 / vocab_size, scores)
|
||||
return vocab_size, cur_len, input_ids, scores
|
||||
|
||||
rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
|
||||
|
||||
scores = rep_penalty_proc(input_ids, tf.identity(scores), cur_len)
|
||||
|
||||
# check that values were correctly changed
|
||||
def _check_repetition_penalty_outputs(self, scores, vocab_size):
    """Assert that `scores` holds the expected repetition-penalized values.

    The expected values correspond to a penalty factor of 2: a negative
    score of a used token is multiplied by 2, a positive score of a used
    token is divided by 2, and the score of an unused token is unchanged.

    Args:
        scores: 2D indexable (batch row, token id) whose elements expose
            `.numpy()`; the output of the repetition penalty processor.
        vocab_size: vocabulary size used to build the input scores; the
            expected values are expressed as multiples of `1 / vocab_size`.
    """
    # Batch row 0: token 0 carried a negative score, token 1 a positive one.
    self.assertAlmostEqual(scores[0, 0].numpy(), -(1 / vocab_size) * 2)
    self.assertAlmostEqual(scores[0, 1].numpy(), (1 / vocab_size) / 2)
    self.assertAlmostEqual(scores[0, 2].numpy(), (1 / vocab_size))  # unused tokens should see no change

    # Batch row 1: token 0 had the base score, token 5 had the boosted 4 / vocab_size score.
    self.assertAlmostEqual(scores[1, 0].numpy(), (1 / vocab_size) / 2)
    self.assertAlmostEqual(scores[1, 5].numpy(), (4 / vocab_size) / 2)
    # NOTE(review): this check previously repeated `scores[0, 2]`, leaving row 1 without an
    # "unused token" check. Assumes token 2 is unused in row 1 -- confirm against
    # `_get_repetition_penalty_inputs`.
    self.assertAlmostEqual(scores[1, 2].numpy(), (1 / vocab_size))  # unused tokens should see no change
|
||||
|
||||
def test_repetition_penalty_dist_process(self):
    """Apply the repetition penalty processor (no XLA) and verify the resulting scores."""
    vocab_size, cur_len, input_ids, raw_scores = self._get_repetition_penalty_inputs()

    # The processor receives a copy of the scores via tf.identity.
    processor = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
    penalized = processor(input_ids, tf.identity(raw_scores), cur_len)

    self._check_repetition_penalty_outputs(penalized, vocab_size)
|
||||
|
||||
def test_repetition_penalty_dist_process_xla(self):
    """Same scenario as the non-XLA test, with the processor wrapped in an XLA-compiled tf.function."""
    vocab_size, cur_len, input_ids, raw_scores = self._get_repetition_penalty_inputs()

    xla_processor = tf.function(
        TFRepetitionPenaltyLogitsProcessor(penalty=2.0),
        jit_compile=True,  # the only difference with respect to the non-XLA test
    )
    penalized = xla_processor(input_ids, tf.identity(raw_scores), cur_len)

    self._check_repetition_penalty_outputs(penalized, vocab_size)
|
||||
|
||||
def test_top_k_dist_warper(self):
|
||||
input_ids = None
|
||||
|
||||
Reference in New Issue
Block a user