Generate: Export TF generate with a TF tokenizer (#22310)
* Export TF generate with a TF tokenizer * remove unused lines
This commit is contained in:
@@ -1725,14 +1725,13 @@ class TFGenerationMixin:
|
||||
|
||||
# 2-to-n generation steps can then be run in autoregressive fashion
|
||||
# only in case 1st generation step does NOT yield EOS token though
|
||||
if greedy_search_cond_fn(generated, finished_sequences, cur_len, model_kwargs):
|
||||
maximum_iterations = max_length - cur_len
|
||||
generated, _, cur_len, _ = tf.while_loop(
|
||||
greedy_search_cond_fn,
|
||||
greedy_search_body_fn,
|
||||
(generated, finished_sequences, cur_len, model_kwargs),
|
||||
maximum_iterations=maximum_iterations,
|
||||
)
|
||||
maximum_iterations = max_length - cur_len
|
||||
generated, _, cur_len, _ = tf.while_loop(
|
||||
greedy_search_cond_fn,
|
||||
greedy_search_body_fn,
|
||||
(generated, finished_sequences, cur_len, model_kwargs),
|
||||
maximum_iterations=maximum_iterations,
|
||||
)
|
||||
|
||||
# 6. prepare outputs
|
||||
if not use_xla:
|
||||
@@ -2016,14 +2015,13 @@ class TFGenerationMixin:
|
||||
|
||||
# 2-to-n generation steps can then be run in autoregressive fashion
|
||||
# only in case 1st generation step does NOT yield EOS token though
|
||||
if sample_cond_fn(generated, finished_sequences, cur_len, model_kwargs):
|
||||
maximum_iterations = max_length - cur_len
|
||||
generated, _, cur_len, _ = tf.while_loop(
|
||||
sample_cond_fn,
|
||||
sample_body_fn,
|
||||
(generated, finished_sequences, cur_len, model_kwargs),
|
||||
maximum_iterations=maximum_iterations,
|
||||
)
|
||||
maximum_iterations = max_length - cur_len
|
||||
generated, _, cur_len, _ = tf.while_loop(
|
||||
sample_cond_fn,
|
||||
sample_body_fn,
|
||||
(generated, finished_sequences, cur_len, model_kwargs),
|
||||
maximum_iterations=maximum_iterations,
|
||||
)
|
||||
|
||||
# 6. prepare outputs
|
||||
if not use_xla:
|
||||
@@ -2565,7 +2563,8 @@ class TFGenerationMixin:
|
||||
|
||||
# 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does
|
||||
# NOT yield EOS token though)
|
||||
if beam_search_cond_fn(
|
||||
maximum_iterations = max_length - cur_len
|
||||
(
|
||||
cur_len,
|
||||
running_sequences,
|
||||
running_scores,
|
||||
@@ -2574,9 +2573,10 @@ class TFGenerationMixin:
|
||||
scores,
|
||||
beam_indices,
|
||||
is_sent_finished,
|
||||
model_kwargs,
|
||||
):
|
||||
maximum_iterations = max_length - cur_len
|
||||
_,
|
||||
) = tf.while_loop(
|
||||
beam_search_cond_fn,
|
||||
beam_search_body_fn,
|
||||
(
|
||||
cur_len,
|
||||
running_sequences,
|
||||
@@ -2586,23 +2586,10 @@ class TFGenerationMixin:
|
||||
scores,
|
||||
beam_indices,
|
||||
is_sent_finished,
|
||||
_,
|
||||
) = tf.while_loop(
|
||||
beam_search_cond_fn,
|
||||
beam_search_body_fn,
|
||||
(
|
||||
cur_len,
|
||||
running_sequences,
|
||||
running_scores,
|
||||
running_beam_indices,
|
||||
sequences,
|
||||
scores,
|
||||
beam_indices,
|
||||
is_sent_finished,
|
||||
model_kwargs,
|
||||
),
|
||||
maximum_iterations=maximum_iterations,
|
||||
)
|
||||
model_kwargs,
|
||||
),
|
||||
maximum_iterations=maximum_iterations,
|
||||
)
|
||||
|
||||
# 6. prepare outputs
|
||||
# Account for the edge-case where there are no finished sequences for a particular batch item. If so, return
|
||||
@@ -3019,22 +3006,13 @@ class TFGenerationMixin:
|
||||
|
||||
# 2-to-n generation steps can then be run in autoregressive fashion
|
||||
# only in case 1st generation step does NOT yield EOS token though
|
||||
if contrastive_search_cond_fn(
|
||||
generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables
|
||||
):
|
||||
maximum_iterations = max_length - cur_len
|
||||
(
|
||||
generated,
|
||||
_,
|
||||
cur_len,
|
||||
_,
|
||||
_,
|
||||
) = tf.while_loop(
|
||||
contrastive_search_cond_fn,
|
||||
contrastive_search_body_fn,
|
||||
(generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables),
|
||||
maximum_iterations=maximum_iterations,
|
||||
)
|
||||
maximum_iterations = max_length - cur_len
|
||||
generated, _, cur_len, _, _ = tf.while_loop(
|
||||
contrastive_search_cond_fn,
|
||||
contrastive_search_body_fn,
|
||||
(generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables),
|
||||
maximum_iterations=maximum_iterations,
|
||||
)
|
||||
|
||||
# 6. prepare outputs
|
||||
if not use_xla:
|
||||
|
||||
@@ -13,13 +13,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from transformers import is_tf_available
|
||||
from transformers.testing_utils import require_tf, slow
|
||||
from transformers import is_tensorflow_text_available, is_tf_available
|
||||
from transformers.testing_utils import require_tensorflow_text, require_tf, slow
|
||||
|
||||
from ..test_modeling_tf_common import floats_tensor
|
||||
from .test_framework_agnostic import GenerationIntegrationTestsMixin
|
||||
@@ -40,6 +42,9 @@ if is_tf_available():
|
||||
tf_top_k_top_p_filtering,
|
||||
)
|
||||
|
||||
if is_tensorflow_text_available():
|
||||
import tensorflow_text as text
|
||||
|
||||
|
||||
@require_tf
|
||||
class UtilsFunctionsTest(unittest.TestCase):
|
||||
@@ -239,6 +244,36 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
|
||||
tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_new_tokens)
|
||||
tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)
|
||||
|
||||
@slow
|
||||
@require_tensorflow_text
|
||||
def test_generate_tf_function_export_with_tf_tokenizer(self):
|
||||
# TF-only test: tf.saved_model export
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# file needed to load the TF tokenizer
|
||||
hf_hub_download(repo_id="google/flan-t5-small", filename="spiece.model", local_dir=tmp_dir)
|
||||
|
||||
class CompleteSentenceTransformer(tf.keras.layers.Layer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.tokenizer = text.SentencepieceTokenizer(
|
||||
model=tf.io.gfile.GFile(os.path.join(tmp_dir, "spiece.model"), "rb").read()
|
||||
)
|
||||
self.model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
def call(self, inputs, *args, **kwargs):
|
||||
tokens = self.tokenizer.tokenize(inputs)
|
||||
input_ids, attention_mask = text.pad_model_inputs(
|
||||
tokens, max_seq_length=64, pad_value=self.model.config.pad_token_id
|
||||
)
|
||||
outputs = self.model.generate(input_ids=input_ids, attention_mask=attention_mask)
|
||||
return self.tokenizer.detokenize(outputs)
|
||||
|
||||
complete_model = CompleteSentenceTransformer()
|
||||
inputs = tf.keras.layers.Input(shape=(1,), dtype=tf.string, name="inputs")
|
||||
outputs = complete_model(inputs)
|
||||
keras_model = tf.keras.Model(inputs, outputs)
|
||||
keras_model.save(tmp_dir)
|
||||
|
||||
def test_eos_token_id_int_and_list_top_k_top_sampling(self):
|
||||
# Has PT equivalent: this test relies on random sampling
|
||||
generation_kwargs = {
|
||||
|
||||
Reference in New Issue
Block a user