fix: multilingual midel convert to tflite get wrong token (#32079)
* fix: multilingual midel convert to tflite get wrong token * fix: modify test_force_tokens_logits_processor the checking value as scores.dtype.min --------- Co-authored-by: kent.sc.hung <kent.sc.hung@benq.com> Co-authored-by: Aya <[kent831217@gmail.com]>
This commit is contained in:
@@ -581,7 +581,7 @@ class TFForceTokensLogitsProcessor(TFLogitsProcessor):
|
|||||||
batch_size = scores.shape[0]
|
batch_size = scores.shape[0]
|
||||||
current_token = self.force_token_array[generation_idx]
|
current_token = self.force_token_array[generation_idx]
|
||||||
|
|
||||||
new_scores = tf.ones_like(scores, dtype=scores.dtype) * -float("inf")
|
new_scores = tf.zeros_like(scores, dtype=scores.dtype) + tf.constant([scores.dtype.min])
|
||||||
indices = tf.stack((tf.range(batch_size), tf.tile([current_token], [batch_size])), axis=1)
|
indices = tf.stack((tf.range(batch_size), tf.tile([current_token], [batch_size])), axis=1)
|
||||||
updates = tf.zeros((batch_size,), dtype=scores.dtype)
|
updates = tf.zeros((batch_size,), dtype=scores.dtype)
|
||||||
new_scores = tf.tensor_scatter_nd_update(new_scores, indices, updates)
|
new_scores = tf.tensor_scatter_nd_update(new_scores, indices, updates)
|
||||||
|
|||||||
@@ -406,7 +406,12 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
|||||||
|
|
||||||
non_forced_inds = [i for i in range(vocab_size) if i != force_token_map[cur_len]]
|
non_forced_inds = [i for i in range(vocab_size) if i != force_token_map[cur_len]]
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
tf.math.reduce_all(tf.math.is_inf(tf.gather(scores, [non_forced_inds], axis=1))),
|
tf.math.reduce_all(
|
||||||
|
tf.experimental.numpy.isclose(
|
||||||
|
tf.gather(scores, [non_forced_inds], axis=1),
|
||||||
|
tf.constant(scores.dtype.min),
|
||||||
|
)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# check that if the cur_len is not contained in the force_token_map, the logits are not modified
|
# check that if the cur_len is not contained in the force_token_map, the logits are not modified
|
||||||
|
|||||||
Reference in New Issue
Block a user