Adding FlaxNoRepeatNGramLogitsProcessor (#29677)

* fix issue with logit processor in beam search in Flax

* adding FlaxNoRepeatNGramLogitsProcessor class + unit test

* style correction and code verification

* add FlaxNoRepeatNGramLogitsProcessor to the test_processor_list and test_processor_list_jitted tests

* fix an issue where ngrams are banned only if they appear ==1 time + update description of get_previous_ngrams

* replace non-jit compatible masking of ngrams that are not yet generated with jittable version

* Revert "fix issue with logit processor in beam search in Flax"

This reverts commit 09b70d7e4dc32d0cc4db61af09a835a9cd238b50.

* add FlaxNoRepeatNGramLogitsProcessor to _get_logits_processor

* change the method of casting to boolean of banned tokens indices

* fix code style

* remove some useless operations + significantly faster computation of update indices using jax.lax.fori_loop

* remove useless loop iterations

* set some variables that were calculated and used multiple times

* fix format
This commit is contained in:
théo gigant
2024-04-02 11:39:33 +02:00
committed by GitHub
parent 33288ff150
commit fed27ffc7e
4 changed files with 135 additions and 2 deletions

View File

@@ -33,6 +33,7 @@ if is_flax_available():
FlaxForcedEOSTokenLogitsProcessor,
FlaxLogitsProcessorList,
FlaxMinLengthLogitsProcessor,
FlaxNoRepeatNGramLogitsProcessor,
FlaxTemperatureLogitsWarper,
FlaxTopKLogitsWarper,
FlaxTopPLogitsWarper,
@@ -197,6 +198,26 @@ class LogitsProcessorTest(unittest.TestCase):
scores = logits_processor(input_ids, scores, cur_len=cur_len)
self.assertFalse(jnp.isinf(scores).any())
def test_no_repeat_ngram_dist_processor(self):
vocab_size = 3
batch_size = 2
cur_len = 4
input_ids = np.array([[1, 1, 2, 1], [0, 1, 0, 1]], dtype="i4")
scores = self._get_uniform_logits(batch_size, vocab_size)
no_repeat_proc_2_gram = FlaxNoRepeatNGramLogitsProcessor(2)
no_repeat_proc_3_gram = FlaxNoRepeatNGramLogitsProcessor(3)
filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, scores, cur_len=cur_len)
filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, scores, cur_len=cur_len)
# 2-gram would forbid 2nd and 3rd token (1,2) at 1st batch and 1st token (0) at 2nd batch
self.assertListEqual(jnp.isinf(filtered_scores_2_gram).tolist(), [[False, True, True], [True, False, False]])
# 3-gram would forbid no token at 1st batch and 1st token (0) at 2nd batch
self.assertListEqual(jnp.isinf(filtered_scores_3_gram).tolist(), [[False, False, False], [True, False, False]])
def test_processor_list(self):
batch_size = 4
sequence_length = 10
@@ -216,6 +237,7 @@ class LogitsProcessorTest(unittest.TestCase):
temp_dist_warp = FlaxTemperatureLogitsWarper(temperature=0.5)
top_k_warp = FlaxTopKLogitsWarper(3)
top_p_warp = FlaxTopPLogitsWarper(0.8)
no_repeat_proc = FlaxNoRepeatNGramLogitsProcessor(2)
# instantiate all logits processors
min_dist_proc = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
@@ -231,10 +253,19 @@ class LogitsProcessorTest(unittest.TestCase):
scores = min_dist_proc(input_ids, scores, cur_len=cur_len)
scores = bos_dist_proc(input_ids, scores, cur_len=cur_len)
scores = eos_dist_proc(input_ids, scores, cur_len=cur_len)
scores = no_repeat_proc(input_ids, scores, cur_len=cur_len)
# with processor list
processor = FlaxLogitsProcessorList(
[temp_dist_warp, top_k_warp, top_p_warp, min_dist_proc, bos_dist_proc, eos_dist_proc]
[
temp_dist_warp,
top_k_warp,
top_p_warp,
min_dist_proc,
bos_dist_proc,
eos_dist_proc,
no_repeat_proc,
]
)
scores_comp = processor(input_ids, scores_comp, cur_len=cur_len)
@@ -263,6 +294,7 @@ class LogitsProcessorTest(unittest.TestCase):
temp_dist_warp = FlaxTemperatureLogitsWarper(temperature=0.5)
top_k_warp = FlaxTopKLogitsWarper(3)
top_p_warp = FlaxTopPLogitsWarper(0.8)
no_repeat_proc = FlaxNoRepeatNGramLogitsProcessor(2)
# instantiate all logits processors
min_dist_proc = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
@@ -279,12 +311,21 @@ class LogitsProcessorTest(unittest.TestCase):
scores = min_dist_proc(input_ids, scores, cur_len=cur_len)
scores = bos_dist_proc(input_ids, scores, cur_len=cur_len)
scores = eos_dist_proc(input_ids, scores, cur_len=cur_len)
scores = no_repeat_proc(input_ids, scores, cur_len=cur_len)
return scores
# with processor list
def run_processor_list(input_ids, scores, cur_len):
processor = FlaxLogitsProcessorList(
[temp_dist_warp, top_k_warp, top_p_warp, min_dist_proc, bos_dist_proc, eos_dist_proc]
[
temp_dist_warp,
top_k_warp,
top_p_warp,
min_dist_proc,
bos_dist_proc,
eos_dist_proc,
no_repeat_proc,
]
)
scores = processor(input_ids, scores, cur_len=cur_len)
return scores