TF: XLA logits processors - minimum length, forced eos, and forced bos (#16912)
* XLA min len, forced eos, and forced bos Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
@@ -18,6 +18,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from parameterized import parameterized
|
||||
from transformers import is_tf_available
|
||||
from transformers.testing_utils import require_tf
|
||||
|
||||
@@ -47,12 +48,15 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
||||
scores = tf.ones((batch_size, length), dtype=tf.float32) / length
|
||||
return scores
|
||||
|
||||
def test_min_length_dist_processor(self):
|
||||
@parameterized.expand([(False,), (True,)])
|
||||
def test_min_length_dist_processor(self, use_xla):
|
||||
vocab_size = 20
|
||||
batch_size = 4
|
||||
eos_token_id = 0
|
||||
|
||||
min_dist_processor = TFMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
|
||||
if use_xla:
|
||||
min_dist_processor = tf.function(min_dist_processor, jit_compile=True)
|
||||
|
||||
# check that min length is applied at length 5
|
||||
cur_len = 5
|
||||
@@ -256,12 +260,15 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
||||
[[True, True, False, True, True], [True, True, True, False, True]],
|
||||
)
|
||||
|
||||
def test_forced_bos_token_logits_processor(self):
|
||||
@parameterized.expand([(False,), (True,)])
|
||||
def test_forced_bos_token_logits_processor(self, use_xla):
|
||||
vocab_size = 20
|
||||
batch_size = 4
|
||||
bos_token_id = 0
|
||||
|
||||
logits_processor = TFForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
|
||||
if use_xla:
|
||||
logits_processor = tf.function(logits_processor, jit_compile=True)
|
||||
|
||||
# check that all scores are -inf except the bos_token_id score
|
||||
cur_len = 1
|
||||
@@ -280,13 +287,16 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
||||
scores = logits_processor(input_ids, scores, cur_len)
|
||||
self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores))))
|
||||
|
||||
def test_forced_eos_token_logits_processor(self):
|
||||
@parameterized.expand([(False,), (True,)])
|
||||
def test_forced_eos_token_logits_processor(self, use_xla):
|
||||
vocab_size = 20
|
||||
batch_size = 4
|
||||
eos_token_id = 0
|
||||
max_length = 5
|
||||
|
||||
logits_processor = TFForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
|
||||
if use_xla:
|
||||
logits_processor = tf.function(logits_processor, jit_compile=True)
|
||||
|
||||
# check that all scores are -inf except the eos_token_id when max_length-1 is reached
|
||||
cur_len = 4
|
||||
|
||||
Reference in New Issue
Block a user