TF: remove set_tensor_by_indices_to_value (#16729)
This commit is contained in:
@@ -37,7 +37,6 @@ if is_tf_available():
|
||||
TFTopKLogitsWarper,
|
||||
TFTopPLogitsWarper,
|
||||
)
|
||||
from transformers.tf_utils import set_tensor_by_indices_to_value
|
||||
|
||||
from ..test_modeling_tf_common import ids_tensor
|
||||
|
||||
@@ -112,9 +111,9 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
||||
scores = self._get_uniform_logits(batch_size=2, length=vocab_size)
|
||||
|
||||
mask = tf.cast(tf.constant([[1] + 9 * [0], 10 * [0]]), tf.bool)
|
||||
scores = set_tensor_by_indices_to_value(scores, mask, -1 / vocab_size)
|
||||
scores = tf.where(mask, -1 / vocab_size, scores)
|
||||
mask = tf.cast(tf.constant([10 * [0], 5 * [0] + [1] + 4 * [0]]), tf.bool)
|
||||
scores = set_tensor_by_indices_to_value(scores, mask, 4 / vocab_size)
|
||||
scores = tf.where(mask, 4 / vocab_size, scores)
|
||||
|
||||
rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
|
||||
|
||||
@@ -340,8 +339,8 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
||||
scores_comp = processor(input_ids, scores_comp, cur_len=cur_len)
|
||||
|
||||
# remove inf
|
||||
scores = set_tensor_by_indices_to_value(scores, tf.math.is_inf(scores), -1e9)
|
||||
scores_comp = set_tensor_by_indices_to_value(scores_comp, tf.math.is_inf(scores_comp), -1e9)
|
||||
scores = tf.where(tf.math.is_inf(scores), -1e9, scores)
|
||||
scores_comp = tf.where(tf.math.is_inf(scores_comp), -1e9, scores_comp)
|
||||
|
||||
# scores should be equal
|
||||
tf.debugging.assert_near(scores, scores_comp, atol=1e-3)
|
||||
|
||||
Reference in New Issue
Block a user