TF: XLA bad words logits processor and list of processors (#16974)
This commit is contained in:
@@ -14,7 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
from typing import List
|
from typing import List, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
@@ -38,7 +38,10 @@ TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
|
|||||||
[What are input IDs?](../glossary#input-ids)
|
[What are input IDs?](../glossary#input-ids)
|
||||||
scores (`tf.Tensor` of shape `(batch_size, config.vocab_size)`):
|
scores (`tf.Tensor` of shape `(batch_size, config.vocab_size)`):
|
||||||
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
|
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
|
||||||
search or log softmax for each vocabulary token when using beam search
|
search or log softmax for each vocabulary token when using beam search.
|
||||||
|
cur_len (`int`):
|
||||||
|
The current length of valid input sequence tokens. In the TF implementation, the input_ids' sequence length
|
||||||
|
is the maximum length generate can produce, and we need to know which of its tokens are valid.
|
||||||
kwargs:
|
kwargs:
|
||||||
Additional logits processor specific kwargs.
|
Additional logits processor specific kwargs.
|
||||||
|
|
||||||
@@ -51,7 +54,7 @@ class TFLogitsProcessor:
|
|||||||
"""Abstract base class for all logit processors that can be applied during generation."""
|
"""Abstract base class for all logit processors that can be applied during generation."""
|
||||||
|
|
||||||
@add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
@add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||||
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
|
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
||||||
"""TF method for processing logits."""
|
"""TF method for processing logits."""
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
||||||
@@ -62,7 +65,7 @@ class TFLogitsWarper:
|
|||||||
"""Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
|
"""Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
|
||||||
|
|
||||||
@add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
@add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||||
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
|
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
||||||
"""TF method for warping logits."""
|
"""TF method for warping logits."""
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
||||||
@@ -77,18 +80,18 @@ class TFLogitsProcessorList(list):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
@add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||||
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, **kwargs) -> tf.Tensor:
|
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int, **kwargs) -> tf.Tensor:
|
||||||
for processor in self:
|
for processor in self:
|
||||||
function_args = inspect.signature(processor.__call__).parameters
|
function_args = inspect.signature(processor.__call__).parameters
|
||||||
if len(function_args) > 2:
|
if len(function_args) > 3:
|
||||||
if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
|
if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Make sure that all the required parameters: {list(function_args.keys())} for "
|
f"Make sure that all the required parameters: {list(function_args.keys())} for "
|
||||||
f"{processor.__class__} are passed to the logits processor."
|
f"{processor.__class__} are passed to the logits processor."
|
||||||
)
|
)
|
||||||
scores = processor(input_ids, scores, **kwargs)
|
scores = processor(input_ids, scores, cur_len, **kwargs)
|
||||||
else:
|
else:
|
||||||
scores = processor(input_ids, scores)
|
scores = processor(input_ids, scores, cur_len)
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
@@ -107,7 +110,7 @@ class TFTemperatureLogitsWarper(TFLogitsWarper):
|
|||||||
|
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
|
|
||||||
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
|
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
||||||
scores = scores / self.temperature
|
scores = scores / self.temperature
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
@@ -133,7 +136,7 @@ class TFTopKLogitsWarper(TFLogitsWarper):
|
|||||||
self.filter_value = filter_value
|
self.filter_value = filter_value
|
||||||
self.min_tokens_to_keep = min_tokens_to_keep
|
self.min_tokens_to_keep = min_tokens_to_keep
|
||||||
|
|
||||||
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
|
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
||||||
top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.shape[-1]) # Safety check
|
top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.shape[-1]) # Safety check
|
||||||
# Boolean mask containing all tokens with a probability less than the last token of the top-k
|
# Boolean mask containing all tokens with a probability less than the last token of the top-k
|
||||||
indices_to_remove = scores < tf.math.top_k(scores, k=top_k)[0][..., -1:]
|
indices_to_remove = scores < tf.math.top_k(scores, k=top_k)[0][..., -1:]
|
||||||
@@ -163,7 +166,7 @@ class TFTopPLogitsWarper(TFLogitsWarper):
|
|||||||
self.filter_value = filter_value
|
self.filter_value = filter_value
|
||||||
self.min_tokens_to_keep = min_tokens_to_keep
|
self.min_tokens_to_keep = min_tokens_to_keep
|
||||||
|
|
||||||
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
|
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
||||||
topk_scores, topk_indices = tf.math.top_k(scores, scores.shape[-1])
|
topk_scores, topk_indices = tf.math.top_k(scores, scores.shape[-1])
|
||||||
|
|
||||||
mask_scores = tf.fill(scores.shape, self.filter_value)
|
mask_scores = tf.fill(scores.shape, self.filter_value)
|
||||||
@@ -305,58 +308,75 @@ class TFNoBadWordsLogitsProcessor(TFLogitsProcessor):
|
|||||||
f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
|
f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
|
||||||
)
|
)
|
||||||
|
|
||||||
self.bad_words_ids = bad_words_ids
|
# stores the information about bad words in three tensors:
|
||||||
|
# 1. a rectangular tensor with the forbidden sequences (padded with `-1`), for full data comparisons
|
||||||
|
self.bad_word_seqs_ids = tf.ragged.constant(bad_words_ids).to_tensor(default_value=-1)
|
||||||
|
# 2. a tensor with the unpadded length of each forbidden sequence, for quick length comparisons
|
||||||
|
bad_word_seqs_len = [len(bad_words) for bad_words in bad_words_ids]
|
||||||
|
if any([word_len == 0 for word_len in bad_word_seqs_len]):
|
||||||
|
raise ValueError(f"Banned words token sequences {bad_words_ids} cannot have an empty list")
|
||||||
|
self.bad_word_seqs_len = tf.convert_to_tensor(bad_word_seqs_len, dtype=tf.int32)
|
||||||
|
# 3. a tensor containing the last token for each sequence, for easy access to the tokens that may be banned
|
||||||
|
self.seq_forbidden_tokens = tf.convert_to_tensor([bad_words[-1] for bad_words in bad_words_ids])
|
||||||
|
|
||||||
def calc_banned_bad_words_ids(self, prev_input_ids):
|
def _calc_row_banned_bad_tokens(self, row_input_ids: tf.Tensor) -> tf.Tensor:
|
||||||
banned_tokens = []
|
def _tokens_match(bad_word_seq_number):
|
||||||
|
def _len_one():
|
||||||
def _tokens_match(prev_tokens, tokens):
|
# If the bad sequence only has one token, always mask it
|
||||||
if len(tokens) == 0:
|
return tf.cond(
|
||||||
# if bad word tokens is just one token always ban it
|
tf.math.equal(self.bad_word_seqs_len[bad_word_seq_number], 1),
|
||||||
return True
|
lambda: tf.ones((), dtype=tf.bool),
|
||||||
if len(tokens) > len(prev_tokens):
|
_len_greater_than_cur_len,
|
||||||
# if bad word tokens are longer than prev tokens they can't be equal
|
|
||||||
return False
|
|
||||||
|
|
||||||
if prev_tokens[-len(tokens) :] == tokens:
|
|
||||||
# if tokens match
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
for prev_input_ids_slice in prev_input_ids:
|
|
||||||
banned_tokens_slice = []
|
|
||||||
|
|
||||||
for banned_token_seq in self.bad_words_ids:
|
|
||||||
assert (
|
|
||||||
len(banned_token_seq) > 0
|
|
||||||
), f"Banned words token sequences {self.bad_words_ids} cannot have an empty list"
|
|
||||||
|
|
||||||
if _tokens_match(prev_input_ids_slice.numpy().tolist(), banned_token_seq[:-1]) is False:
|
|
||||||
# if tokens do not match continue
|
|
||||||
continue
|
|
||||||
|
|
||||||
banned_tokens_slice.append(banned_token_seq[-1])
|
|
||||||
|
|
||||||
banned_tokens.append(banned_tokens_slice)
|
|
||||||
|
|
||||||
return banned_tokens
|
|
||||||
|
|
||||||
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
|
||||||
|
|
||||||
vocab_size = scores.shape[-1]
|
|
||||||
|
|
||||||
# calculate a list of banned tokens according to bad words
|
|
||||||
banned_tokens = self.calc_banned_bad_words_ids(input_ids[:, :cur_len])
|
|
||||||
|
|
||||||
banned_tokens_indices_mask = []
|
|
||||||
for banned_tokens_slice in banned_tokens:
|
|
||||||
banned_tokens_indices_mask.append(
|
|
||||||
[True if token in banned_tokens_slice else False for token in range(vocab_size)]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
scores = tf.where(tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores)
|
def _len_greater_than_cur_len():
|
||||||
|
# Otherwise, if the bad sequence is longer than the current length they can't ever match
|
||||||
|
return tf.cond(
|
||||||
|
tf.math.greater(self.bad_word_seqs_len[bad_word_seq_number], row_input_ids.shape[0]),
|
||||||
|
lambda: tf.zeros((), dtype=tf.bool),
|
||||||
|
_match_found,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _match_found():
|
||||||
|
# Finaly, runs the actual comparison. Can only be called if the previous comparisons do not yield
|
||||||
|
# an answer (otherwise we get indexing exceptions)
|
||||||
|
compare_len = self.bad_word_seqs_len[bad_word_seq_number] - 1
|
||||||
|
return tf.cond(
|
||||||
|
tf.math.reduce_all(
|
||||||
|
tf.math.equal(
|
||||||
|
row_input_ids[-compare_len:], self.bad_word_seqs_ids[bad_word_seq_number, :compare_len]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
lambda: tf.ones((), dtype=tf.bool),
|
||||||
|
lambda: tf.zeros((), dtype=tf.bool),
|
||||||
|
)
|
||||||
|
|
||||||
|
match = _len_one()
|
||||||
|
return match
|
||||||
|
|
||||||
|
# Compares the current row against all bad word sequences, obtaining a mask with the matches.
|
||||||
|
match_mask = tf.map_fn(_tokens_match, tf.range(self.bad_word_seqs_ids.shape[0]), fn_output_signature=tf.bool)
|
||||||
|
row_banned_tokens = self.seq_forbidden_tokens[match_mask]
|
||||||
|
return row_banned_tokens
|
||||||
|
|
||||||
|
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
||||||
|
# We want to mask some banned tokens, at a score level. Since the banned tokens depend on the previous
|
||||||
|
# `input_ids`, they may have a different length for each row, and they may even be empty for some rows.
|
||||||
|
# To remain simple and XLA-compatible, we work on a per-row fashion.
|
||||||
|
# TODO (Joao): this function might trigger XLA retracing as `cur_len` increases. Fix it if it becomes
|
||||||
|
# a frequent choke point. (make `cur_len` a tensor?)
|
||||||
|
def _get_row_updated_score(row_inputs: Tuple[tf.Tensor]) -> tf.Tensor:
|
||||||
|
row_input_ids, row_score = row_inputs
|
||||||
|
banned_tokens = self._calc_row_banned_bad_tokens(row_input_ids[:cur_len])
|
||||||
|
banned_tokens_mask = tf.scatter_nd(
|
||||||
|
indices=tf.expand_dims(banned_tokens, axis=-1),
|
||||||
|
updates=tf.ones_like(banned_tokens, dtype=tf.bool),
|
||||||
|
shape=row_score.shape,
|
||||||
|
)
|
||||||
|
row_score = tf.where(banned_tokens_mask, -float("inf"), row_score)
|
||||||
|
return row_score
|
||||||
|
|
||||||
|
scores = tf.map_fn(_get_row_updated_score, (input_ids, scores), fn_output_signature=tf.float32)
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
@@ -401,6 +421,11 @@ class TFNoRepeatNGramLogitsProcessor(TFLogitsProcessor):
|
|||||||
|
|
||||||
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
||||||
|
|
||||||
|
# TODO (joao): enable XLA on this logits processor. See discussion and attempts in
|
||||||
|
# https://github.com/huggingface/transformers/pull/16974
|
||||||
|
if not tf.executing_eagerly():
|
||||||
|
raise NotImplementedError("TFNoRepeatNGramLogitsProcessor is only implemented for eager execution.")
|
||||||
|
|
||||||
batch_size, vocab_size = scores.shape
|
batch_size, vocab_size = scores.shape
|
||||||
banned_tokens = self.calc_banned_ngram_tokens(input_ids, batch_size, cur_len)
|
banned_tokens = self.calc_banned_ngram_tokens(input_ids, batch_size, cur_len)
|
||||||
|
|
||||||
|
|||||||
@@ -2030,7 +2030,7 @@ class TFGenerationMixin:
|
|||||||
if not use_xla:
|
if not use_xla:
|
||||||
input_ids = tf.reshape(generated.concat(), (-1, batch_size))
|
input_ids = tf.reshape(generated.concat(), (-1, batch_size))
|
||||||
input_ids = tf.transpose(input_ids[: current_pos[0]])
|
input_ids = tf.transpose(input_ids[: current_pos[0]])
|
||||||
next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len=current_pos[0])
|
next_tokens_scores = logits_processor(input_ids, next_token_logits, current_pos[0])
|
||||||
|
|
||||||
# argmax
|
# argmax
|
||||||
next_tokens = tf.argmax(next_tokens_scores, axis=-1, output_type=tf.int32)
|
next_tokens = tf.argmax(next_tokens_scores, axis=-1, output_type=tf.int32)
|
||||||
@@ -2301,8 +2301,8 @@ class TFGenerationMixin:
|
|||||||
if not use_xla:
|
if not use_xla:
|
||||||
input_ids = tf.reshape(generated.concat(), (-1, batch_size))
|
input_ids = tf.reshape(generated.concat(), (-1, batch_size))
|
||||||
input_ids = tf.transpose(input_ids[:cur_len])
|
input_ids = tf.transpose(input_ids[:cur_len])
|
||||||
next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len=cur_len)
|
next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len)
|
||||||
next_tokens_scores = logits_warper(input_ids, next_tokens_scores)
|
next_tokens_scores = logits_warper(input_ids, next_tokens_scores, cur_len)
|
||||||
|
|
||||||
# sample
|
# sample
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
@@ -2726,7 +2726,7 @@ class TFGenerationMixin:
|
|||||||
# add new logprobs to existing running logprobs scores.
|
# add new logprobs to existing running logprobs scores.
|
||||||
log_probs = tf.nn.log_softmax(logits)
|
log_probs = tf.nn.log_softmax(logits)
|
||||||
log_probs = logits_processor(
|
log_probs = logits_processor(
|
||||||
flatten_beam_dim(running_sequences_seq_last), flatten_beam_dim(log_probs), cur_len=cur_len
|
flatten_beam_dim(running_sequences_seq_last), flatten_beam_dim(log_probs), cur_len
|
||||||
)
|
)
|
||||||
log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
|
log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
|
||||||
log_probs = log_probs + tf.expand_dims(running_scores, axis=2)
|
log_probs = log_probs + tf.expand_dims(running_scores, axis=2)
|
||||||
|
|||||||
@@ -75,6 +75,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
|||||||
@parameterized.expand([(False,), (True,)])
|
@parameterized.expand([(False,), (True,)])
|
||||||
def test_temperature_dist_warper(self, use_xla):
|
def test_temperature_dist_warper(self, use_xla):
|
||||||
input_ids = None
|
input_ids = None
|
||||||
|
cur_len = None
|
||||||
length = 20
|
length = 20
|
||||||
|
|
||||||
scores = self._get_uniform_logits(batch_size=2, length=length)
|
scores = self._get_uniform_logits(batch_size=2, length=length)
|
||||||
@@ -94,8 +95,8 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
|||||||
temp_dist_warper_sharper = tf.function(temp_dist_warper_sharper, jit_compile=True)
|
temp_dist_warper_sharper = tf.function(temp_dist_warper_sharper, jit_compile=True)
|
||||||
temp_dist_warper_smoother = tf.function(temp_dist_warper_smoother, jit_compile=True)
|
temp_dist_warper_smoother = tf.function(temp_dist_warper_smoother, jit_compile=True)
|
||||||
|
|
||||||
warped_prob_sharp = tf.nn.softmax(temp_dist_warper_sharper(input_ids, tf.identity(scores)), axis=-1)
|
warped_prob_sharp = tf.nn.softmax(temp_dist_warper_sharper(input_ids, tf.identity(scores), cur_len), axis=-1)
|
||||||
warped_prob_smooth = tf.nn.softmax(temp_dist_warper_smoother(input_ids, tf.identity(scores)), axis=-1)
|
warped_prob_smooth = tf.nn.softmax(temp_dist_warper_smoother(input_ids, tf.identity(scores), cur_len), axis=-1)
|
||||||
|
|
||||||
# uniform distribution stays uniform
|
# uniform distribution stays uniform
|
||||||
tf.debugging.assert_near(probs[0, :], warped_prob_sharp[0, :], atol=1e-3)
|
tf.debugging.assert_near(probs[0, :], warped_prob_sharp[0, :], atol=1e-3)
|
||||||
@@ -142,6 +143,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
|||||||
@parameterized.expand([(False,), (True,)])
|
@parameterized.expand([(False,), (True,)])
|
||||||
def test_top_k_dist_warper(self, use_xla):
|
def test_top_k_dist_warper(self, use_xla):
|
||||||
input_ids = None
|
input_ids = None
|
||||||
|
cur_len = None
|
||||||
vocab_size = 10
|
vocab_size = 10
|
||||||
batch_size = 2
|
batch_size = 2
|
||||||
|
|
||||||
@@ -153,7 +155,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
|||||||
if use_xla:
|
if use_xla:
|
||||||
top_k_warp = tf.function(top_k_warp, jit_compile=True)
|
top_k_warp = tf.function(top_k_warp, jit_compile=True)
|
||||||
|
|
||||||
scores = top_k_warp(input_ids, ramp_logits)
|
scores = top_k_warp(input_ids, ramp_logits, cur_len)
|
||||||
|
|
||||||
# check that correct tokens are filtered
|
# check that correct tokens are filtered
|
||||||
self.assertListEqual(tf.math.is_inf(scores[0]).numpy().tolist(), 7 * [True] + 3 * [False])
|
self.assertListEqual(tf.math.is_inf(scores[0]).numpy().tolist(), 7 * [True] + 3 * [False])
|
||||||
@@ -167,12 +169,12 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
|||||||
if use_xla:
|
if use_xla:
|
||||||
top_k_warp_safety_check = tf.function(top_k_warp_safety_check, jit_compile=True)
|
top_k_warp_safety_check = tf.function(top_k_warp_safety_check, jit_compile=True)
|
||||||
|
|
||||||
scores = top_k_warp_safety_check(input_ids, logits)
|
scores = top_k_warp_safety_check(input_ids, logits, cur_len)
|
||||||
# uniform dist is not changed
|
# uniform dist is not changed
|
||||||
self.assertListEqual(tf.math.reduce_sum(tf.where(scores == 0.0, 1, 0), axis=-1).numpy().tolist(), [0, 0])
|
self.assertListEqual(tf.math.reduce_sum(tf.where(scores == 0.0, 1, 0), axis=-1).numpy().tolist(), [0, 0])
|
||||||
|
|
||||||
ramp_logits = np.broadcast_to(np.arange(length, dtype=np.float32), (batch_size, length)).copy()
|
ramp_logits = np.broadcast_to(np.arange(length, dtype=np.float32), (batch_size, length)).copy()
|
||||||
scores = top_k_warp_safety_check(input_ids, ramp_logits)
|
scores = top_k_warp_safety_check(input_ids, ramp_logits, cur_len)
|
||||||
|
|
||||||
# min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified
|
# min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified
|
||||||
self.assertListEqual(tf.math.reduce_sum(tf.where(scores == 0.0, 1, 0), axis=-1).numpy().tolist(), [2, 2])
|
self.assertListEqual(tf.math.reduce_sum(tf.where(scores == 0.0, 1, 0), axis=-1).numpy().tolist(), [2, 2])
|
||||||
@@ -180,6 +182,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
|||||||
@parameterized.expand([(False,), (True,)])
|
@parameterized.expand([(False,), (True,)])
|
||||||
def test_top_p_dist_warper(self, use_xla):
|
def test_top_p_dist_warper(self, use_xla):
|
||||||
input_ids = None
|
input_ids = None
|
||||||
|
cur_len = None
|
||||||
vocab_size = 10
|
vocab_size = 10
|
||||||
batch_size = 2
|
batch_size = 2
|
||||||
|
|
||||||
@@ -189,7 +192,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
|||||||
top_p_warp = TFTopPLogitsWarper(0.7)
|
top_p_warp = TFTopPLogitsWarper(0.7)
|
||||||
if use_xla:
|
if use_xla:
|
||||||
top_p_warp = tf.function(top_p_warp, jit_compile=True)
|
top_p_warp = tf.function(top_p_warp, jit_compile=True)
|
||||||
filtered_dist = tf.exp(top_p_warp(input_ids, dist))
|
filtered_dist = tf.exp(top_p_warp(input_ids, dist, cur_len))
|
||||||
|
|
||||||
# dist should be filtered to keep min num values so that sum is >= 0.7
|
# dist should be filtered to keep min num values so that sum is >= 0.7
|
||||||
# exp (-inf) => 0
|
# exp (-inf) => 0
|
||||||
@@ -208,7 +211,7 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
|||||||
top_p_warp = TFTopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0)
|
top_p_warp = TFTopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0)
|
||||||
if use_xla:
|
if use_xla:
|
||||||
top_p_warp = tf.function(top_p_warp, jit_compile=True)
|
top_p_warp = tf.function(top_p_warp, jit_compile=True)
|
||||||
filtered_dist = top_p_warp(input_ids, ramp_logits)
|
filtered_dist = top_p_warp(input_ids, ramp_logits, cur_len)
|
||||||
|
|
||||||
# first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps
|
# first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps
|
||||||
# 2.
|
# 2.
|
||||||
@@ -242,7 +245,8 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
|||||||
tf.math.is_inf(filtered_scores_3_gram).numpy().tolist(), [[False, False, False], [True, False, False]]
|
tf.math.is_inf(filtered_scores_3_gram).numpy().tolist(), [[False, False, False], [True, False, False]]
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_no_bad_words_dist_processor(self):
|
@parameterized.expand([(False,), (True,)])
|
||||||
|
def test_no_bad_words_dist_processor(self, use_xla):
|
||||||
vocab_size = 5
|
vocab_size = 5
|
||||||
batch_size = 2
|
batch_size = 2
|
||||||
eos_token_id = 4
|
eos_token_id = 4
|
||||||
@@ -255,6 +259,8 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
|||||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||||
|
|
||||||
no_bad_words_dist_proc = TFNoBadWordsLogitsProcessor(bad_words_ids=bad_word_tokens, eos_token_id=eos_token_id)
|
no_bad_words_dist_proc = TFNoBadWordsLogitsProcessor(bad_words_ids=bad_word_tokens, eos_token_id=eos_token_id)
|
||||||
|
if use_xla:
|
||||||
|
no_bad_words_dist_proc = tf.function(no_bad_words_dist_proc, jit_compile=True)
|
||||||
|
|
||||||
filtered_scores = no_bad_words_dist_proc(input_ids, tf.identity(scores), cur_len)
|
filtered_scores = no_bad_words_dist_proc(input_ids, tf.identity(scores), cur_len)
|
||||||
|
|
||||||
@@ -322,7 +328,9 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
|||||||
scores = logits_processor(input_ids, scores, cur_len)
|
scores = logits_processor(input_ids, scores, cur_len)
|
||||||
self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores))))
|
self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores))))
|
||||||
|
|
||||||
def test_processor_list(self):
|
@parameterized.expand([(False,), (True,)])
|
||||||
|
def test_processor_list(self, use_xla):
|
||||||
|
# TODO (Joao): reintroduce TFNoRepeatNGramLogitsProcessor when it gets compatible with XLA
|
||||||
batch_size = 4
|
batch_size = 4
|
||||||
cur_len = 10
|
cur_len = 10
|
||||||
vocab_size = 15
|
vocab_size = 15
|
||||||
@@ -341,16 +349,24 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
|||||||
rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
|
rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
|
||||||
top_k_warp = TFTopKLogitsWarper(3)
|
top_k_warp = TFTopKLogitsWarper(3)
|
||||||
top_p_warp = TFTopPLogitsWarper(0.8)
|
top_p_warp = TFTopPLogitsWarper(0.8)
|
||||||
no_repeat_proc = TFNoRepeatNGramLogitsProcessor(2)
|
# no_repeat_proc = TFNoRepeatNGramLogitsProcessor(2)
|
||||||
no_bad_words_dist_proc = TFNoBadWordsLogitsProcessor(bad_words_ids=[[1]], eos_token_id=eos_token_id)
|
no_bad_words_dist_proc = TFNoBadWordsLogitsProcessor(bad_words_ids=[[1]], eos_token_id=eos_token_id)
|
||||||
|
if use_xla:
|
||||||
|
min_dist_proc = tf.function(min_dist_proc, jit_compile=True)
|
||||||
|
temp_dist_warp = tf.function(temp_dist_warp, jit_compile=True)
|
||||||
|
rep_penalty_proc = tf.function(rep_penalty_proc, jit_compile=True)
|
||||||
|
top_k_warp = tf.function(top_k_warp, jit_compile=True)
|
||||||
|
top_p_warp = tf.function(top_p_warp, jit_compile=True)
|
||||||
|
# no_repeat_proc = tf.function(no_repeat_proc, jit_compile=True)
|
||||||
|
no_bad_words_dist_proc = tf.function(no_bad_words_dist_proc, jit_compile=True)
|
||||||
|
|
||||||
# no processor list
|
# no processor list
|
||||||
scores = min_dist_proc(input_ids, scores, cur_len)
|
scores = min_dist_proc(input_ids, scores, cur_len)
|
||||||
scores = temp_dist_warp(input_ids, scores)
|
scores = temp_dist_warp(input_ids, scores, cur_len)
|
||||||
scores = rep_penalty_proc(input_ids, scores, cur_len)
|
scores = rep_penalty_proc(input_ids, scores, cur_len)
|
||||||
scores = top_k_warp(input_ids, scores)
|
scores = top_k_warp(input_ids, scores, cur_len)
|
||||||
scores = top_p_warp(input_ids, scores)
|
scores = top_p_warp(input_ids, scores, cur_len)
|
||||||
scores = no_repeat_proc(input_ids, scores, cur_len)
|
# scores = no_repeat_proc(input_ids, scores, cur_len)
|
||||||
scores = no_bad_words_dist_proc(input_ids, scores, cur_len)
|
scores = no_bad_words_dist_proc(input_ids, scores, cur_len)
|
||||||
|
|
||||||
# with processor list
|
# with processor list
|
||||||
@@ -361,11 +377,11 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
|||||||
rep_penalty_proc,
|
rep_penalty_proc,
|
||||||
top_k_warp,
|
top_k_warp,
|
||||||
top_p_warp,
|
top_p_warp,
|
||||||
no_repeat_proc,
|
# no_repeat_proc,
|
||||||
no_bad_words_dist_proc,
|
no_bad_words_dist_proc,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
scores_comp = processor(input_ids, scores_comp, cur_len=cur_len)
|
scores_comp = processor(input_ids, scores_comp, cur_len)
|
||||||
|
|
||||||
# remove inf
|
# remove inf
|
||||||
scores = tf.where(tf.math.is_inf(scores), -1e9, scores)
|
scores = tf.where(tf.math.is_inf(scores), -1e9, scores)
|
||||||
|
|||||||
Reference in New Issue
Block a user