TF: XLA logits processors - minimum length, forced eos, and forced bos (#16912)
* XLA min len, forced eos, and forced bos Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
@@ -215,13 +215,18 @@ class TFMinLengthLogitsProcessor(TFLogitsProcessor):
|
|||||||
self.min_length = min_length
|
self.min_length = min_length
|
||||||
self.eos_token_id = eos_token_id
|
self.eos_token_id = eos_token_id
|
||||||
|
|
||||||
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
def _apply_eos_token_mask(self, scores: tf.Tensor) -> tf.Tensor:
|
||||||
# TODO(Matt) - this if statement has to be rewritten for XLA. Leaving it now though since
|
eos_token_id_mask = tf.range(scores.shape[-1]) == self.eos_token_id
|
||||||
# generate is not XLA - compileable anyways
|
|
||||||
if cur_len < self.min_length:
|
|
||||||
eos_token_id_mask = tf.broadcast_to(tf.range(scores.shape[-1]) == self.eos_token_id, scores.shape)
|
|
||||||
scores = tf.where(eos_token_id_mask, float("-inf"), scores)
|
scores = tf.where(eos_token_id_mask, float("-inf"), scores)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
||||||
|
# applies eos token masking if the first argument is true
|
||||||
|
scores = tf.cond(
|
||||||
|
tf.less(cur_len, self.min_length),
|
||||||
|
lambda: self._apply_eos_token_mask(scores),
|
||||||
|
lambda: tf.identity(scores),
|
||||||
|
)
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import unittest
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from parameterized import parameterized
|
||||||
from transformers import is_tf_available
|
from transformers import is_tf_available
|
||||||
from transformers.testing_utils import require_tf
|
from transformers.testing_utils import require_tf
|
||||||
|
|
||||||
@@ -47,12 +48,15 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
|||||||
scores = tf.ones((batch_size, length), dtype=tf.float32) / length
|
scores = tf.ones((batch_size, length), dtype=tf.float32) / length
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
def test_min_length_dist_processor(self):
|
@parameterized.expand([(False,), (True,)])
|
||||||
|
def test_min_length_dist_processor(self, use_xla):
|
||||||
vocab_size = 20
|
vocab_size = 20
|
||||||
batch_size = 4
|
batch_size = 4
|
||||||
eos_token_id = 0
|
eos_token_id = 0
|
||||||
|
|
||||||
min_dist_processor = TFMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
|
min_dist_processor = TFMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
|
||||||
|
if use_xla:
|
||||||
|
min_dist_processor = tf.function(min_dist_processor, jit_compile=True)
|
||||||
|
|
||||||
# check that min length is applied at length 5
|
# check that min length is applied at length 5
|
||||||
cur_len = 5
|
cur_len = 5
|
||||||
@@ -256,12 +260,15 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
|||||||
[[True, True, False, True, True], [True, True, True, False, True]],
|
[[True, True, False, True, True], [True, True, True, False, True]],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_forced_bos_token_logits_processor(self):
|
@parameterized.expand([(False,), (True,)])
|
||||||
|
def test_forced_bos_token_logits_processor(self, use_xla):
|
||||||
vocab_size = 20
|
vocab_size = 20
|
||||||
batch_size = 4
|
batch_size = 4
|
||||||
bos_token_id = 0
|
bos_token_id = 0
|
||||||
|
|
||||||
logits_processor = TFForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
|
logits_processor = TFForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
|
||||||
|
if use_xla:
|
||||||
|
logits_processor = tf.function(logits_processor, jit_compile=True)
|
||||||
|
|
||||||
# check that all scores are -inf except the bos_token_id score
|
# check that all scores are -inf except the bos_token_id score
|
||||||
cur_len = 1
|
cur_len = 1
|
||||||
@@ -280,13 +287,16 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
|||||||
scores = logits_processor(input_ids, scores, cur_len)
|
scores = logits_processor(input_ids, scores, cur_len)
|
||||||
self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores))))
|
self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores))))
|
||||||
|
|
||||||
def test_forced_eos_token_logits_processor(self):
|
@parameterized.expand([(False,), (True,)])
|
||||||
|
def test_forced_eos_token_logits_processor(self, use_xla):
|
||||||
vocab_size = 20
|
vocab_size = 20
|
||||||
batch_size = 4
|
batch_size = 4
|
||||||
eos_token_id = 0
|
eos_token_id = 0
|
||||||
max_length = 5
|
max_length = 5
|
||||||
|
|
||||||
logits_processor = TFForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
|
logits_processor = TFForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
|
||||||
|
if use_xla:
|
||||||
|
logits_processor = tf.function(logits_processor, jit_compile=True)
|
||||||
|
|
||||||
# check that all scores are -inf except the eos_token_id when max_length-1 is reached
|
# check that all scores are -inf except the eos_token_id when max_length-1 is reached
|
||||||
cur_len = 4
|
cur_len = 4
|
||||||
|
|||||||
Reference in New Issue
Block a user