From 12febc20dbb5e93afe9b9f509dfdc12c5c800c6a Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 22 Mar 2023 15:00:20 +0000 Subject: [PATCH] Generate: Export TF generate with a TF tokenizer (#22310) * Export TF generate with a TF tokenizer * remove unused lines --- src/transformers/generation/tf_utils.py | 84 +++++++++---------------- tests/generation/test_tf_utils.py | 39 +++++++++++- 2 files changed, 68 insertions(+), 55 deletions(-) diff --git a/src/transformers/generation/tf_utils.py b/src/transformers/generation/tf_utils.py index 4a9140f885..749c07d547 100644 --- a/src/transformers/generation/tf_utils.py +++ b/src/transformers/generation/tf_utils.py @@ -1725,14 +1725,13 @@ class TFGenerationMixin: # 2-to-n generation steps can then be run in autoregressive fashion # only in case 1st generation step does NOT yield EOS token though - if greedy_search_cond_fn(generated, finished_sequences, cur_len, model_kwargs): - maximum_iterations = max_length - cur_len - generated, _, cur_len, _ = tf.while_loop( - greedy_search_cond_fn, - greedy_search_body_fn, - (generated, finished_sequences, cur_len, model_kwargs), - maximum_iterations=maximum_iterations, - ) + maximum_iterations = max_length - cur_len + generated, _, cur_len, _ = tf.while_loop( + greedy_search_cond_fn, + greedy_search_body_fn, + (generated, finished_sequences, cur_len, model_kwargs), + maximum_iterations=maximum_iterations, + ) # 6. prepare outputs if not use_xla: @@ -2016,14 +2015,13 @@ class TFGenerationMixin: # 2-to-n generation steps can then be run in autoregressive fashion # only in case 1st generation step does NOT yield EOS token though - if sample_cond_fn(generated, finished_sequences, cur_len, model_kwargs): - maximum_iterations = max_length - cur_len - generated, _, cur_len, _ = tf.while_loop( - sample_cond_fn, - sample_body_fn, - (generated, finished_sequences, cur_len, model_kwargs), - maximum_iterations=maximum_iterations, - ) + maximum_iterations = max_length - cur_len + generated, _, cur_len, _ = tf.while_loop( + sample_cond_fn, + sample_body_fn, + (generated, finished_sequences, cur_len, model_kwargs), + maximum_iterations=maximum_iterations, + ) # 6. prepare outputs if not use_xla: @@ -2565,7 +2563,8 @@ class TFGenerationMixin: # 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does # NOT yield EOS token though) - if beam_search_cond_fn( + maximum_iterations = max_length - cur_len + ( cur_len, running_sequences, running_scores, @@ -2574,9 +2573,10 @@ class TFGenerationMixin: scores, beam_indices, is_sent_finished, - model_kwargs, - ): - maximum_iterations = max_length - cur_len + _, + ) = tf.while_loop( + beam_search_cond_fn, + beam_search_body_fn, ( cur_len, running_sequences, @@ -2586,23 +2586,10 @@ class TFGenerationMixin: scores, beam_indices, is_sent_finished, - _, - ) = tf.while_loop( - beam_search_cond_fn, - beam_search_body_fn, - ( - cur_len, - running_sequences, - running_scores, - running_beam_indices, - sequences, - scores, - beam_indices, - is_sent_finished, - model_kwargs, - ), - maximum_iterations=maximum_iterations, - ) + model_kwargs, + ), + maximum_iterations=maximum_iterations, + ) # 6. prepare outputs # Account for the edge-case where there are no finished sequences for a particular batch item. If so, return @@ -3019,22 +3006,13 @@ class TFGenerationMixin: # 2-to-n generation steps can then be run in autoregressive fashion # only in case 1st generation step does NOT yield EOS token though - if contrastive_search_cond_fn( - generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables - ): - maximum_iterations = max_length - cur_len - ( - generated, - _, - cur_len, - _, - _, - ) = tf.while_loop( - contrastive_search_cond_fn, - contrastive_search_body_fn, - (generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables), - maximum_iterations=maximum_iterations, - ) + maximum_iterations = max_length - cur_len + generated, _, cur_len, _, _ = tf.while_loop( + contrastive_search_cond_fn, + contrastive_search_body_fn, + (generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables), + maximum_iterations=maximum_iterations, + ) # 6. prepare outputs if not use_xla: diff --git a/tests/generation/test_tf_utils.py b/tests/generation/test_tf_utils.py index cab4512bec..6fdad1ef63 100644 --- a/tests/generation/test_tf_utils.py +++ b/tests/generation/test_tf_utils.py @@ -13,13 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import tempfile import unittest import numpy as np +from huggingface_hub import hf_hub_download -from transformers import is_tf_available -from transformers.testing_utils import require_tf, slow +from transformers import is_tensorflow_text_available, is_tf_available +from transformers.testing_utils import require_tensorflow_text, require_tf, slow from ..test_modeling_tf_common import floats_tensor from .test_framework_agnostic import GenerationIntegrationTestsMixin @@ -40,6 +42,9 @@ if is_tf_available(): tf_top_k_top_p_filtering, ) +if is_tensorflow_text_available(): + import tensorflow_text as text + @require_tf class UtilsFunctionsTest(unittest.TestCase): @@ -239,6 +244,36 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_new_tokens) tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs) + @slow + @require_tensorflow_text + def test_generate_tf_function_export_with_tf_tokenizer(self): + # TF-only test: tf.saved_model export + with tempfile.TemporaryDirectory() as tmp_dir: + # file needed to load the TF tokenizer + hf_hub_download(repo_id="google/flan-t5-small", filename="spiece.model", local_dir=tmp_dir) + + class CompleteSentenceTransformer(tf.keras.layers.Layer): + def __init__(self): + super().__init__() + self.tokenizer = text.SentencepieceTokenizer( + model=tf.io.gfile.GFile(os.path.join(tmp_dir, "spiece.model"), "rb").read() + ) + self.model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5") + + def call(self, inputs, *args, **kwargs): + tokens = self.tokenizer.tokenize(inputs) + input_ids, attention_mask = text.pad_model_inputs( + tokens, max_seq_length=64, pad_value=self.model.config.pad_token_id + ) + outputs = self.model.generate(input_ids=input_ids, attention_mask=attention_mask) + return self.tokenizer.detokenize(outputs) + + complete_model = CompleteSentenceTransformer() + inputs = tf.keras.layers.Input(shape=(1,), dtype=tf.string, name="inputs") + outputs = complete_model(inputs) + keras_model = tf.keras.Model(inputs, outputs) + keras_model.save(tmp_dir) + def test_eos_token_id_int_and_list_top_k_top_sampling(self): # Has PT equivalent: this test relies on random sampling generation_kwargs = {