From d7f7f29f29f7267ad895514e3a5054b35091d152 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 12 Apr 2022 17:51:47 +0100 Subject: [PATCH] TF: remove set_tensor_by_indices_to_value (#16729) --- .../generation_tf_logits_process.py | 11 +++-------- src/transformers/generation_tf_utils.py | 17 ++++++++--------- src/transformers/tf_utils.py | 5 ----- .../test_generation_tf_logits_process.py | 9 ++++----- 4 files changed, 15 insertions(+), 27 deletions(-) diff --git a/src/transformers/generation_tf_logits_process.py b/src/transformers/generation_tf_logits_process.py index 92bae58eb8..48df31a3ff 100644 --- a/src/transformers/generation_tf_logits_process.py +++ b/src/transformers/generation_tf_logits_process.py @@ -19,7 +19,6 @@ from typing import List import numpy as np import tensorflow as tf -from .tf_utils import set_tensor_by_indices_to_value from .utils import add_start_docstrings from .utils.logging import get_logger @@ -221,7 +220,7 @@ class TFMinLengthLogitsProcessor(TFLogitsProcessor): # generate is not XLA - compileable anyways if cur_len < self.min_length: eos_token_id_mask = tf.broadcast_to(tf.range(scores.shape[-1]) == self.eos_token_id, scores.shape) - scores = set_tensor_by_indices_to_value(scores, eos_token_id_mask, float("-inf")) + scores = tf.where(eos_token_id_mask, float("-inf"), scores) return scores @@ -339,9 +338,7 @@ class TFNoBadWordsLogitsProcessor(TFLogitsProcessor): [True if token in banned_tokens_slice else False for token in range(vocab_size)] ) - scores = set_tensor_by_indices_to_value( - scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf") - ) + scores = tf.where(tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores) return scores @@ -397,9 +394,7 @@ class TFNoRepeatNGramLogitsProcessor(TFLogitsProcessor): [True if token in banned_tokens_slice else False for token in range(vocab_size)] ) - scores = set_tensor_by_indices_to_value( - scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf") - ) + scores = tf.where(tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores) return scores diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py index b9ae5bd77b..83ba8a84cd 100644 --- a/src/transformers/generation_tf_utils.py +++ b/src/transformers/generation_tf_utils.py @@ -34,7 +34,7 @@ from .generation_tf_logits_process import ( TFTopKLogitsWarper, TFTopPLogitsWarper, ) -from .tf_utils import set_tensor_by_indices_to_value, shape_list +from .tf_utils import shape_list from .utils import ModelOutput, logging @@ -952,8 +952,7 @@ class TFGenerationMixin: [True if token == eos_token_id else False for token in range(vocab_size)], dtype=tf.bool ) eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [num_batch_hypotheses, vocab_size]) - - scores = set_tensor_by_indices_to_value(scores, eos_token_indices_mask, -float("inf")) + scores = tf.where(eos_token_indices_mask, -float("inf"), scores) if no_repeat_ngram_size > 0: # calculate a list of banned tokens to prevent repetitively generating the same ngrams @@ -969,8 +968,8 @@ class TFGenerationMixin: [True if token in banned_tokens_slice else False for token in range(vocab_size)] ) - scores = set_tensor_by_indices_to_value( - scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf") + scores = tf.where( + tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores ) if bad_words_ids is not None: @@ -983,8 +982,8 @@ class TFGenerationMixin: [True if token in banned_tokens_slice else False for token in range(vocab_size)] ) - scores = set_tensor_by_indices_to_value( - scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf") + scores = tf.where( + tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores ) assert shape_list(scores) == [batch_size * num_beams, vocab_size] @@ -2950,7 +2949,7 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In top_k = min(max(top_k, min_tokens_to_keep), logits_shape[-1]) # Safety check # Remove all tokens with a probability less than the last token of the top-k indices_to_remove = logits < tf.math.top_k(logits, k=top_k)[0][..., -1, None] - logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value) + logits = tf.where(indices_to_remove, filter_value, logits) if top_p < 1.0: sorted_indices = tf.argsort(logits, direction="DESCENDING") sorted_logits = tf.gather( @@ -2979,7 +2978,7 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In ) # scatter sorted tensors to original indexing indices_to_remove = scatter_values_on_batch_indices(sorted_indices_to_remove, sorted_indices) - logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value) + logits = tf.where(indices_to_remove, filter_value, logits) return logits diff --git a/src/transformers/tf_utils.py b/src/transformers/tf_utils.py index 42c744be7a..c0d076b31c 100644 --- a/src/transformers/tf_utils.py +++ b/src/transformers/tf_utils.py @@ -23,11 +23,6 @@ from .utils import logging logger = logging.get_logger(__name__) -def set_tensor_by_indices_to_value(tensor: tf.Tensor, indices: tf.Tensor, value: Union[tf.Tensor, int, float]): - # create value_tensor since tensor value assignment is not possible in TF - return tf.where(indices, value, tensor) - - def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> List[int]: """ Deal with dynamic shape in tensorflow cleanly. diff --git a/tests/generation/test_generation_tf_logits_process.py b/tests/generation/test_generation_tf_logits_process.py index 8a5ee05368..60fa3352e5 100644 --- a/tests/generation/test_generation_tf_logits_process.py +++ b/tests/generation/test_generation_tf_logits_process.py @@ -37,7 +37,6 @@ if is_tf_available(): TFTopKLogitsWarper, TFTopPLogitsWarper, ) - from transformers.tf_utils import set_tensor_by_indices_to_value from ..test_modeling_tf_common import ids_tensor @@ -112,9 +111,9 @@ class TFLogitsProcessorTest(unittest.TestCase): scores = self._get_uniform_logits(batch_size=2, length=vocab_size) mask = tf.cast(tf.constant([[1] + 9 * [0], 10 * [0]]), tf.bool) - scores = set_tensor_by_indices_to_value(scores, mask, -1 / vocab_size) + scores = tf.where(mask, -1 / vocab_size, scores) mask = tf.cast(tf.constant([10 * [0], 5 * [0] + [1] + 4 * [0]]), tf.bool) - scores = set_tensor_by_indices_to_value(scores, mask, 4 / vocab_size) + scores = tf.where(mask, 4 / vocab_size, scores) rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0) @@ -340,8 +339,8 @@ class TFLogitsProcessorTest(unittest.TestCase): scores_comp = processor(input_ids, scores_comp, cur_len=cur_len) # remove inf - scores = set_tensor_by_indices_to_value(scores, tf.math.is_inf(scores), -1e9) - scores_comp = set_tensor_by_indices_to_value(scores_comp, tf.math.is_inf(scores_comp), -1e9) + scores = tf.where(tf.math.is_inf(scores), -1e9, scores) + scores_comp = tf.where(tf.math.is_inf(scores_comp), -1e9, scores_comp) # scores should be equal tf.debugging.assert_near(scores, scores_comp, atol=1e-3)