diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py index 315d5b08a7..6653f3c8d1 100644 --- a/src/transformers/generation/__init__.py +++ b/src/transformers/generation/__init__.py @@ -162,6 +162,7 @@ else: "FlaxTopKLogitsWarper", "FlaxTopPLogitsWarper", "FlaxWhisperTimeStampLogitsProcessor", + "FlaxNoRepeatNGramLogitsProcessor", ] _import_structure["flax_utils"] = [ "FlaxGenerationMixin", @@ -294,6 +295,7 @@ if TYPE_CHECKING: FlaxLogitsProcessorList, FlaxLogitsWarper, FlaxMinLengthLogitsProcessor, + FlaxNoRepeatNGramLogitsProcessor, FlaxSuppressTokensAtBeginLogitsProcessor, FlaxSuppressTokensLogitsProcessor, FlaxTemperatureLogitsWarper, diff --git a/src/transformers/generation/flax_logits_process.py b/src/transformers/generation/flax_logits_process.py index 5c30b92755..84b5a38d5d 100644 --- a/src/transformers/generation/flax_logits_process.py +++ b/src/transformers/generation/flax_logits_process.py @@ -18,6 +18,7 @@ import inspect import jax import jax.lax as lax import jax.numpy as jnp +from jax.experimental import sparse from ..utils import add_start_docstrings from ..utils.logging import get_logger @@ -455,3 +456,89 @@ class FlaxWhisperTimeStampLogitsProcessor(FlaxLogitsProcessor): scores = jax.vmap(handle_cumulative_probs)(logprobs, scores) return scores + + +class FlaxNoRepeatNGramLogitsProcessor(FlaxLogitsProcessor): + r""" + [`FlaxLogitsProcessor`] that enforces no repetition of n-grams. See + [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345). + + Args: + ngram_size (`int`): + All ngrams of size `ngram_size` can only occur once. + """ + + def __init__(self, ngram_size: int): + if not isinstance(ngram_size, int) or ngram_size <= 0: + raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}") + self.ngram_size = ngram_size + + def get_previous_ngrams(self, input_ids: jnp.ndarray, vocab_size: int, cur_len: int): + """ + get a matrix of size (batch_size,) + (vocab_size,)*n (for n-grams) that + represent the n-grams that occured previously. + The BCOO representation allow to store only the few non-zero entries, instead of the full (huge) matrix + """ + batch_size, seq_len = input_ids.shape + # number of n-grams in the whole sequence + seq_ngrams = seq_len - (self.ngram_size - 1) + # number of n-grams in the currently generated sequence + cur_ngrams = cur_len - (self.ngram_size - 1) + + def body_fun(i, val): + b = i % batch_size + pos = i // batch_size + return val.at[i].set( + jnp.array( + [ + b, + ] + + [jnp.array(input_ids)[b, pos + j] for j in range(self.ngram_size)] + ) + ) + + shape = (batch_size * seq_ngrams, self.ngram_size + 1) + all_update_indices = jax.lax.fori_loop( + 0, batch_size * cur_ngrams, body_fun, jnp.zeros(shape, dtype=input_ids.dtype) + ) + + # ignore the n-grams not yet generated + data = (jnp.arange(batch_size * seq_ngrams) < batch_size * cur_ngrams).astype("float32") + + return sparse.BCOO((data, all_update_indices), shape=(batch_size,) + (vocab_size,) * self.ngram_size) + + def get_banned_tokens_mask(self, latest_tokens: jnp.ndarray, previous_ngrams) -> jnp.ndarray: + """ + Determines which tokens must be banned given latest tokens and the previously seen + ngrams. + """ + + @sparse.sparsify + @jax.vmap + def inner_fn(latest_tokens, previous_ngrams): + return previous_ngrams[tuple(latest_tokens)] + + return sparse.bcoo_todense(inner_fn(latest_tokens, previous_ngrams)) + + def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray: + def true_fn(): + _, vocab_size = scores.shape + # store the previously seen n-grams + previous_ngrams = self.get_previous_ngrams(input_ids, vocab_size, cur_len) + + # get the n-1 last tokens that prefix the n-gram being generated + latest_tokens = jnp.zeros((input_ids.shape[0], self.ngram_size - 1), dtype=input_ids.dtype) + latest_tokens = jax.lax.dynamic_update_slice( + latest_tokens, + jax.lax.dynamic_slice( + input_ids, (0, cur_len - (self.ngram_size - 1)), (input_ids.shape[0], (self.ngram_size - 1)) + ), + (0, 0), + ) + + # compute the banned tokens, ie all the tokens that when added to the latest tokens lead to a n-gram that was previously generated + banned_tokens_indices_mask = self.get_banned_tokens_mask(latest_tokens, previous_ngrams).astype("bool") + return jnp.where(banned_tokens_indices_mask, -float("inf"), scores) + + output = jax.lax.cond((cur_len >= self.ngram_size - 1), true_fn, lambda: scores) + return output diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index 3a89c1ed41..08480ac983 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -40,6 +40,7 @@ from .flax_logits_process import ( FlaxForceTokensLogitsProcessor, FlaxLogitsProcessorList, FlaxMinLengthLogitsProcessor, + FlaxNoRepeatNGramLogitsProcessor, FlaxSuppressTokensAtBeginLogitsProcessor, FlaxSuppressTokensLogitsProcessor, FlaxTemperatureLogitsWarper, @@ -534,6 +535,8 @@ class FlaxGenerationMixin: [input_ids_seq_length + i[0] - 1, i[1]] for i in generation_config.forced_decoder_ids ] processors.append(FlaxForceTokensLogitsProcessor(forced_decoder_ids)) + if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0: + processors.append(FlaxNoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size)) processors = self._merge_criteria_processor_list(processors, logits_processor) return processors diff --git a/tests/generation/test_flax_logits_process.py b/tests/generation/test_flax_logits_process.py index a45d75ae24..bd5f8f648c 100644 --- a/tests/generation/test_flax_logits_process.py +++ b/tests/generation/test_flax_logits_process.py @@ -33,6 +33,7 @@ if is_flax_available(): FlaxForcedEOSTokenLogitsProcessor, FlaxLogitsProcessorList, FlaxMinLengthLogitsProcessor, + FlaxNoRepeatNGramLogitsProcessor, FlaxTemperatureLogitsWarper, FlaxTopKLogitsWarper, FlaxTopPLogitsWarper, @@ -197,6 +198,26 @@ class LogitsProcessorTest(unittest.TestCase): scores = logits_processor(input_ids, scores, cur_len=cur_len) self.assertFalse(jnp.isinf(scores).any()) + def test_no_repeat_ngram_dist_processor(self): + vocab_size = 3 + batch_size = 2 + + cur_len = 4 + input_ids = np.array([[1, 1, 2, 1], [0, 1, 0, 1]], dtype="i4") + scores = self._get_uniform_logits(batch_size, vocab_size) + + no_repeat_proc_2_gram = FlaxNoRepeatNGramLogitsProcessor(2) + no_repeat_proc_3_gram = FlaxNoRepeatNGramLogitsProcessor(3) + + filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, scores, cur_len=cur_len) + filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, scores, cur_len=cur_len) + + # 2-gram would forbid 2nd and 3rd token (1,2) at 1st batch and 1st token (0) at 2nd batch + self.assertListEqual(jnp.isinf(filtered_scores_2_gram).tolist(), [[False, True, True], [True, False, False]]) + + # 3-gram would forbid no token at 1st batch and 1st token (0) at 2nd batch + self.assertListEqual(jnp.isinf(filtered_scores_3_gram).tolist(), [[False, False, False], [True, False, False]]) + def test_processor_list(self): batch_size = 4 sequence_length = 10 @@ -216,6 +237,7 @@ class LogitsProcessorTest(unittest.TestCase): temp_dist_warp = FlaxTemperatureLogitsWarper(temperature=0.5) top_k_warp = FlaxTopKLogitsWarper(3) top_p_warp = FlaxTopPLogitsWarper(0.8) + no_repeat_proc = FlaxNoRepeatNGramLogitsProcessor(2) # instantiate all logits processors min_dist_proc = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id) @@ -231,10 +253,19 @@ class LogitsProcessorTest(unittest.TestCase): scores = min_dist_proc(input_ids, scores, cur_len=cur_len) scores = bos_dist_proc(input_ids, scores, cur_len=cur_len) scores = eos_dist_proc(input_ids, scores, cur_len=cur_len) + scores = no_repeat_proc(input_ids, scores, cur_len=cur_len) # with processor list processor = FlaxLogitsProcessorList( - [temp_dist_warp, top_k_warp, top_p_warp, min_dist_proc, bos_dist_proc, eos_dist_proc] + [ + temp_dist_warp, + top_k_warp, + top_p_warp, + min_dist_proc, + bos_dist_proc, + eos_dist_proc, + no_repeat_proc, + ] ) scores_comp = processor(input_ids, scores_comp, cur_len=cur_len) @@ -263,6 +294,7 @@ class LogitsProcessorTest(unittest.TestCase): temp_dist_warp = FlaxTemperatureLogitsWarper(temperature=0.5) top_k_warp = FlaxTopKLogitsWarper(3) top_p_warp = FlaxTopPLogitsWarper(0.8) + no_repeat_proc = FlaxNoRepeatNGramLogitsProcessor(2) # instantiate all logits processors min_dist_proc = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id) @@ -279,12 +311,21 @@ class LogitsProcessorTest(unittest.TestCase): scores = min_dist_proc(input_ids, scores, cur_len=cur_len) scores = bos_dist_proc(input_ids, scores, cur_len=cur_len) scores = eos_dist_proc(input_ids, scores, cur_len=cur_len) + scores = no_repeat_proc(input_ids, scores, cur_len=cur_len) return scores # with processor list def run_processor_list(input_ids, scores, cur_len): processor = FlaxLogitsProcessorList( - [temp_dist_warp, top_k_warp, top_p_warp, min_dist_proc, bos_dist_proc, eos_dist_proc] + [ + temp_dist_warp, + top_k_warp, + top_p_warp, + min_dist_proc, + bos_dist_proc, + eos_dist_proc, + no_repeat_proc, + ] ) scores = processor(input_ids, scores, cur_len=cur_len) return scores