TF: XLA logits processors - minimum length, forced eos, and forced bos (#16912)

* XLA logits processors: minimum length, forced EOS, and forced BOS

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
Joao Gante
2022-04-25 19:27:53 +01:00
committed by GitHub
parent f6210c49e2
commit 809dac48f9
2 changed files with 24 additions and 9 deletions

View File

@@ -18,6 +18,7 @@ import unittest
import numpy as np
from parameterized import parameterized
from transformers import is_tf_available
from transformers.testing_utils import require_tf
@@ -47,12 +48,15 @@ class TFLogitsProcessorTest(unittest.TestCase):
scores = tf.ones((batch_size, length), dtype=tf.float32) / length
return scores
def test_min_length_dist_processor(self):
@parameterized.expand([(False,), (True,)])
def test_min_length_dist_processor(self, use_xla):
vocab_size = 20
batch_size = 4
eos_token_id = 0
min_dist_processor = TFMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
if use_xla:
min_dist_processor = tf.function(min_dist_processor, jit_compile=True)
# check that min length is applied at length 5
cur_len = 5
@@ -256,12 +260,15 @@ class TFLogitsProcessorTest(unittest.TestCase):
[[True, True, False, True, True], [True, True, True, False, True]],
)
def test_forced_bos_token_logits_processor(self):
@parameterized.expand([(False,), (True,)])
def test_forced_bos_token_logits_processor(self, use_xla):
vocab_size = 20
batch_size = 4
bos_token_id = 0
logits_processor = TFForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
if use_xla:
logits_processor = tf.function(logits_processor, jit_compile=True)
# check that all scores are -inf except the bos_token_id score
cur_len = 1
@@ -280,13 +287,16 @@ class TFLogitsProcessorTest(unittest.TestCase):
scores = logits_processor(input_ids, scores, cur_len)
self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores))))
def test_forced_eos_token_logits_processor(self):
@parameterized.expand([(False,), (True,)])
def test_forced_eos_token_logits_processor(self, use_xla):
vocab_size = 20
batch_size = 4
eos_token_id = 0
max_length = 5
logits_processor = TFForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
if use_xla:
logits_processor = tf.function(logits_processor, jit_compile=True)
# check that all scores are -inf except the eos_token_id when max_length-1 is reached
cur_len = 4