TF: remove set_tensor_by_indices_to_value (#16729)

This commit is contained in:
Joao Gante
2022-04-12 17:51:47 +01:00
committed by GitHub
parent a315988bae
commit d7f7f29f29
4 changed files with 15 additions and 27 deletions

View File

@@ -37,7 +37,6 @@ if is_tf_available():
TFTopKLogitsWarper,
TFTopPLogitsWarper,
)
from transformers.tf_utils import set_tensor_by_indices_to_value
from ..test_modeling_tf_common import ids_tensor
@@ -112,9 +111,9 @@ class TFLogitsProcessorTest(unittest.TestCase):
scores = self._get_uniform_logits(batch_size=2, length=vocab_size)
mask = tf.cast(tf.constant([[1] + 9 * [0], 10 * [0]]), tf.bool)
scores = set_tensor_by_indices_to_value(scores, mask, -1 / vocab_size)
scores = tf.where(mask, -1 / vocab_size, scores)
mask = tf.cast(tf.constant([10 * [0], 5 * [0] + [1] + 4 * [0]]), tf.bool)
scores = set_tensor_by_indices_to_value(scores, mask, 4 / vocab_size)
scores = tf.where(mask, 4 / vocab_size, scores)
rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
@@ -340,8 +339,8 @@ class TFLogitsProcessorTest(unittest.TestCase):
scores_comp = processor(input_ids, scores_comp, cur_len=cur_len)
# remove inf
scores = set_tensor_by_indices_to_value(scores, tf.math.is_inf(scores), -1e9)
scores_comp = set_tensor_by_indices_to_value(scores_comp, tf.math.is_inf(scores_comp), -1e9)
scores = tf.where(tf.math.is_inf(scores), -1e9, scores)
scores_comp = tf.where(tf.math.is_inf(scores_comp), -1e9, scores_comp)
# scores should be equal
tf.debugging.assert_near(scores, scores_comp, atol=1e-3)