From 809dac48f97cc75ed12e17e7ba739c05fec4c928 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 25 Apr 2022 19:27:53 +0100 Subject: [PATCH] TF: XLA logits processors - minimum length, forced eos, and forced bos (#16912) * XLA min len, forced eos, and forced bos Co-authored-by: Matt --- .../generation_tf_logits_process.py | 17 +++++++++++------ .../test_generation_tf_logits_process.py | 16 +++++++++++++--- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/src/transformers/generation_tf_logits_process.py b/src/transformers/generation_tf_logits_process.py index eefd1f0ace..b771211cea 100644 --- a/src/transformers/generation_tf_logits_process.py +++ b/src/transformers/generation_tf_logits_process.py @@ -215,13 +215,18 @@ class TFMinLengthLogitsProcessor(TFLogitsProcessor): self.min_length = min_length self.eos_token_id = eos_token_id - def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor: - # TODO(Matt) - this if statement has to be rewritten for XLA. Leaving it now though since - # generate is not XLA - compileable anyways - if cur_len < self.min_length: - eos_token_id_mask = tf.broadcast_to(tf.range(scores.shape[-1]) == self.eos_token_id, scores.shape) - scores = tf.where(eos_token_id_mask, float("-inf"), scores) + def _apply_eos_token_mask(self, scores: tf.Tensor) -> tf.Tensor: + eos_token_id_mask = tf.range(scores.shape[-1]) == self.eos_token_id + scores = tf.where(eos_token_id_mask, float("-inf"), scores) + return scores + def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor: + # applies eos token masking if the first argument is true + scores = tf.cond( + tf.less(cur_len, self.min_length), + lambda: self._apply_eos_token_mask(scores), + lambda: tf.identity(scores), + ) return scores diff --git a/tests/generation/test_generation_tf_logits_process.py b/tests/generation/test_generation_tf_logits_process.py index 913b26cb64..06b8e001c0 100644 --- a/tests/generation/test_generation_tf_logits_process.py +++ b/tests/generation/test_generation_tf_logits_process.py @@ -18,6 +18,7 @@ import unittest import numpy as np +from parameterized import parameterized from transformers import is_tf_available from transformers.testing_utils import require_tf @@ -47,12 +48,15 @@ class TFLogitsProcessorTest(unittest.TestCase): scores = tf.ones((batch_size, length), dtype=tf.float32) / length return scores - def test_min_length_dist_processor(self): + @parameterized.expand([(False,), (True,)]) + def test_min_length_dist_processor(self, use_xla): vocab_size = 20 batch_size = 4 eos_token_id = 0 min_dist_processor = TFMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id) + if use_xla: + min_dist_processor = tf.function(min_dist_processor, jit_compile=True) # check that min length is applied at length 5 cur_len = 5 @@ -256,12 +260,15 @@ class TFLogitsProcessorTest(unittest.TestCase): [[True, True, False, True, True], [True, True, True, False, True]], ) - def test_forced_bos_token_logits_processor(self): + @parameterized.expand([(False,), (True,)]) + def test_forced_bos_token_logits_processor(self, use_xla): vocab_size = 20 batch_size = 4 bos_token_id = 0 logits_processor = TFForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id) + if use_xla: + logits_processor = tf.function(logits_processor, jit_compile=True) # check that all scores are -inf except the bos_token_id score cur_len = 1 @@ -280,13 +287,16 @@ class TFLogitsProcessorTest(unittest.TestCase): scores = logits_processor(input_ids, scores, cur_len) self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores)))) - def test_forced_eos_token_logits_processor(self): + @parameterized.expand([(False,), (True,)]) + def test_forced_eos_token_logits_processor(self, use_xla): vocab_size = 20 batch_size = 4 eos_token_id = 0 max_length = 5 logits_processor = TFForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id) + if use_xla: + logits_processor = tf.function(logits_processor, jit_compile=True) # check that all scores are -inf except the eos_token_id when max_length-1 is reached cur_len = 4