Adding FlaxNoRepeatNGramLogitsProcessor (#29677)
* fix issue with logit processor in beam search in Flax
* adding FlaxNoRepeatNGramLogitsProcessor class + unit test
* style correction and code verification
* add FlaxNoRepeatNGramLogitsProcessor to the test_processor_list and test_processor_list_jitted tests
* fix an issue where ngrams are banned only if they appear ==1 time + update description of get_previous_ngrams
* replace non-jit compatible masking of ngrams that are not yet generated with jittable version
* Revert "fix issue with logit processor in beam search in Flax"
  This reverts commit 09b70d7e4dc32d0cc4db61af09a835a9cd238b50.
* add FlaxNoRepeatNGramLogitsProcessor to _get_logits_processor
* change the method of casting to boolean of banned tokens indices
* fix code style
* remove some useless operations + significantly faster computation of update indices using jax.lax.fori_loop
* remove useless loop iterations
* set some variables that were calculated and used multiple times
* fix format
This commit is contained in:
@@ -33,6 +33,7 @@ if is_flax_available():
|
||||
FlaxForcedEOSTokenLogitsProcessor,
|
||||
FlaxLogitsProcessorList,
|
||||
FlaxMinLengthLogitsProcessor,
|
||||
FlaxNoRepeatNGramLogitsProcessor,
|
||||
FlaxTemperatureLogitsWarper,
|
||||
FlaxTopKLogitsWarper,
|
||||
FlaxTopPLogitsWarper,
|
||||
@@ -197,6 +198,26 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
scores = logits_processor(input_ids, scores, cur_len=cur_len)
|
||||
self.assertFalse(jnp.isinf(scores).any())
|
||||
|
||||
def test_no_repeat_ngram_dist_processor(self):
|
||||
vocab_size = 3
|
||||
batch_size = 2
|
||||
|
||||
cur_len = 4
|
||||
input_ids = np.array([[1, 1, 2, 1], [0, 1, 0, 1]], dtype="i4")
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
|
||||
no_repeat_proc_2_gram = FlaxNoRepeatNGramLogitsProcessor(2)
|
||||
no_repeat_proc_3_gram = FlaxNoRepeatNGramLogitsProcessor(3)
|
||||
|
||||
filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, scores, cur_len=cur_len)
|
||||
filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, scores, cur_len=cur_len)
|
||||
|
||||
# 2-gram would forbid 2nd and 3rd token (1,2) at 1st batch and 1st token (0) at 2nd batch
|
||||
self.assertListEqual(jnp.isinf(filtered_scores_2_gram).tolist(), [[False, True, True], [True, False, False]])
|
||||
|
||||
# 3-gram would forbid no token at 1st batch and 1st token (0) at 2nd batch
|
||||
self.assertListEqual(jnp.isinf(filtered_scores_3_gram).tolist(), [[False, False, False], [True, False, False]])
|
||||
|
||||
def test_processor_list(self):
|
||||
batch_size = 4
|
||||
sequence_length = 10
|
||||
@@ -216,6 +237,7 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
temp_dist_warp = FlaxTemperatureLogitsWarper(temperature=0.5)
|
||||
top_k_warp = FlaxTopKLogitsWarper(3)
|
||||
top_p_warp = FlaxTopPLogitsWarper(0.8)
|
||||
no_repeat_proc = FlaxNoRepeatNGramLogitsProcessor(2)
|
||||
|
||||
# instantiate all logits processors
|
||||
min_dist_proc = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
|
||||
@@ -231,10 +253,19 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
scores = min_dist_proc(input_ids, scores, cur_len=cur_len)
|
||||
scores = bos_dist_proc(input_ids, scores, cur_len=cur_len)
|
||||
scores = eos_dist_proc(input_ids, scores, cur_len=cur_len)
|
||||
scores = no_repeat_proc(input_ids, scores, cur_len=cur_len)
|
||||
|
||||
# with processor list
|
||||
processor = FlaxLogitsProcessorList(
|
||||
[temp_dist_warp, top_k_warp, top_p_warp, min_dist_proc, bos_dist_proc, eos_dist_proc]
|
||||
[
|
||||
temp_dist_warp,
|
||||
top_k_warp,
|
||||
top_p_warp,
|
||||
min_dist_proc,
|
||||
bos_dist_proc,
|
||||
eos_dist_proc,
|
||||
no_repeat_proc,
|
||||
]
|
||||
)
|
||||
scores_comp = processor(input_ids, scores_comp, cur_len=cur_len)
|
||||
|
||||
@@ -263,6 +294,7 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
temp_dist_warp = FlaxTemperatureLogitsWarper(temperature=0.5)
|
||||
top_k_warp = FlaxTopKLogitsWarper(3)
|
||||
top_p_warp = FlaxTopPLogitsWarper(0.8)
|
||||
no_repeat_proc = FlaxNoRepeatNGramLogitsProcessor(2)
|
||||
|
||||
# instantiate all logits processors
|
||||
min_dist_proc = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
|
||||
@@ -279,12 +311,21 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
scores = min_dist_proc(input_ids, scores, cur_len=cur_len)
|
||||
scores = bos_dist_proc(input_ids, scores, cur_len=cur_len)
|
||||
scores = eos_dist_proc(input_ids, scores, cur_len=cur_len)
|
||||
scores = no_repeat_proc(input_ids, scores, cur_len=cur_len)
|
||||
return scores
|
||||
|
||||
# with processor list
|
||||
def run_processor_list(input_ids, scores, cur_len):
|
||||
processor = FlaxLogitsProcessorList(
|
||||
[temp_dist_warp, top_k_warp, top_p_warp, min_dist_proc, bos_dist_proc, eos_dist_proc]
|
||||
[
|
||||
temp_dist_warp,
|
||||
top_k_warp,
|
||||
top_p_warp,
|
||||
min_dist_proc,
|
||||
bos_dist_proc,
|
||||
eos_dist_proc,
|
||||
no_repeat_proc,
|
||||
]
|
||||
)
|
||||
scores = processor(input_ids, scores, cur_len=cur_len)
|
||||
return scores
|
||||
|
||||
Reference in New Issue
Block a user