Generate: Export TF generate with a TF tokenizer (#22310)
* Export TF generate with a TF tokenizer
* Remove unused lines
This commit is contained in:
@@ -1725,14 +1725,13 @@ class TFGenerationMixin:
|
|||||||
|
|
||||||
# 2-to-n generation steps can then be run in autoregressive fashion
|
# 2-to-n generation steps can then be run in autoregressive fashion
|
||||||
# only in case 1st generation step does NOT yield EOS token though
|
# only in case 1st generation step does NOT yield EOS token though
|
||||||
if greedy_search_cond_fn(generated, finished_sequences, cur_len, model_kwargs):
|
maximum_iterations = max_length - cur_len
|
||||||
maximum_iterations = max_length - cur_len
|
generated, _, cur_len, _ = tf.while_loop(
|
||||||
generated, _, cur_len, _ = tf.while_loop(
|
greedy_search_cond_fn,
|
||||||
greedy_search_cond_fn,
|
greedy_search_body_fn,
|
||||||
greedy_search_body_fn,
|
(generated, finished_sequences, cur_len, model_kwargs),
|
||||||
(generated, finished_sequences, cur_len, model_kwargs),
|
maximum_iterations=maximum_iterations,
|
||||||
maximum_iterations=maximum_iterations,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# 6. prepare outputs
|
# 6. prepare outputs
|
||||||
if not use_xla:
|
if not use_xla:
|
||||||
@@ -2016,14 +2015,13 @@ class TFGenerationMixin:
|
|||||||
|
|
||||||
# 2-to-n generation steps can then be run in autoregressive fashion
|
# 2-to-n generation steps can then be run in autoregressive fashion
|
||||||
# only in case 1st generation step does NOT yield EOS token though
|
# only in case 1st generation step does NOT yield EOS token though
|
||||||
if sample_cond_fn(generated, finished_sequences, cur_len, model_kwargs):
|
maximum_iterations = max_length - cur_len
|
||||||
maximum_iterations = max_length - cur_len
|
generated, _, cur_len, _ = tf.while_loop(
|
||||||
generated, _, cur_len, _ = tf.while_loop(
|
sample_cond_fn,
|
||||||
sample_cond_fn,
|
sample_body_fn,
|
||||||
sample_body_fn,
|
(generated, finished_sequences, cur_len, model_kwargs),
|
||||||
(generated, finished_sequences, cur_len, model_kwargs),
|
maximum_iterations=maximum_iterations,
|
||||||
maximum_iterations=maximum_iterations,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# 6. prepare outputs
|
# 6. prepare outputs
|
||||||
if not use_xla:
|
if not use_xla:
|
||||||
@@ -2565,7 +2563,8 @@ class TFGenerationMixin:
|
|||||||
|
|
||||||
# 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does
|
# 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does
|
||||||
# NOT yield EOS token though)
|
# NOT yield EOS token though)
|
||||||
if beam_search_cond_fn(
|
maximum_iterations = max_length - cur_len
|
||||||
|
(
|
||||||
cur_len,
|
cur_len,
|
||||||
running_sequences,
|
running_sequences,
|
||||||
running_scores,
|
running_scores,
|
||||||
@@ -2574,9 +2573,10 @@ class TFGenerationMixin:
|
|||||||
scores,
|
scores,
|
||||||
beam_indices,
|
beam_indices,
|
||||||
is_sent_finished,
|
is_sent_finished,
|
||||||
model_kwargs,
|
_,
|
||||||
):
|
) = tf.while_loop(
|
||||||
maximum_iterations = max_length - cur_len
|
beam_search_cond_fn,
|
||||||
|
beam_search_body_fn,
|
||||||
(
|
(
|
||||||
cur_len,
|
cur_len,
|
||||||
running_sequences,
|
running_sequences,
|
||||||
@@ -2586,23 +2586,10 @@ class TFGenerationMixin:
|
|||||||
scores,
|
scores,
|
||||||
beam_indices,
|
beam_indices,
|
||||||
is_sent_finished,
|
is_sent_finished,
|
||||||
_,
|
model_kwargs,
|
||||||
) = tf.while_loop(
|
),
|
||||||
beam_search_cond_fn,
|
maximum_iterations=maximum_iterations,
|
||||||
beam_search_body_fn,
|
)
|
||||||
(
|
|
||||||
cur_len,
|
|
||||||
running_sequences,
|
|
||||||
running_scores,
|
|
||||||
running_beam_indices,
|
|
||||||
sequences,
|
|
||||||
scores,
|
|
||||||
beam_indices,
|
|
||||||
is_sent_finished,
|
|
||||||
model_kwargs,
|
|
||||||
),
|
|
||||||
maximum_iterations=maximum_iterations,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 6. prepare outputs
|
# 6. prepare outputs
|
||||||
# Account for the edge-case where there are no finished sequences for a particular batch item. If so, return
|
# Account for the edge-case where there are no finished sequences for a particular batch item. If so, return
|
||||||
@@ -3019,22 +3006,13 @@ class TFGenerationMixin:
|
|||||||
|
|
||||||
# 2-to-n generation steps can then be run in autoregressive fashion
|
# 2-to-n generation steps can then be run in autoregressive fashion
|
||||||
# only in case 1st generation step does NOT yield EOS token though
|
# only in case 1st generation step does NOT yield EOS token though
|
||||||
if contrastive_search_cond_fn(
|
maximum_iterations = max_length - cur_len
|
||||||
generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables
|
generated, _, cur_len, _, _ = tf.while_loop(
|
||||||
):
|
contrastive_search_cond_fn,
|
||||||
maximum_iterations = max_length - cur_len
|
contrastive_search_body_fn,
|
||||||
(
|
(generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables),
|
||||||
generated,
|
maximum_iterations=maximum_iterations,
|
||||||
_,
|
)
|
||||||
cur_len,
|
|
||||||
_,
|
|
||||||
_,
|
|
||||||
) = tf.while_loop(
|
|
||||||
contrastive_search_cond_fn,
|
|
||||||
contrastive_search_body_fn,
|
|
||||||
(generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables),
|
|
||||||
maximum_iterations=maximum_iterations,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 6. prepare outputs
|
# 6. prepare outputs
|
||||||
if not use_xla:
|
if not use_xla:
|
||||||
|
|||||||
@@ -13,13 +13,15 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
from transformers import is_tf_available
|
from transformers import is_tensorflow_text_available, is_tf_available
|
||||||
from transformers.testing_utils import require_tf, slow
|
from transformers.testing_utils import require_tensorflow_text, require_tf, slow
|
||||||
|
|
||||||
from ..test_modeling_tf_common import floats_tensor
|
from ..test_modeling_tf_common import floats_tensor
|
||||||
from .test_framework_agnostic import GenerationIntegrationTestsMixin
|
from .test_framework_agnostic import GenerationIntegrationTestsMixin
|
||||||
@@ -40,6 +42,9 @@ if is_tf_available():
|
|||||||
tf_top_k_top_p_filtering,
|
tf_top_k_top_p_filtering,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if is_tensorflow_text_available():
|
||||||
|
import tensorflow_text as text
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class UtilsFunctionsTest(unittest.TestCase):
|
class UtilsFunctionsTest(unittest.TestCase):
|
||||||
@@ -239,6 +244,36 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
|
|||||||
tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_new_tokens)
|
tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_new_tokens)
|
||||||
tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)
|
tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_tensorflow_text
|
||||||
|
def test_generate_tf_function_export_with_tf_tokenizer(self):
|
||||||
|
# TF-only test: tf.saved_model export
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
# file needed to load the TF tokenizer
|
||||||
|
hf_hub_download(repo_id="google/flan-t5-small", filename="spiece.model", local_dir=tmp_dir)
|
||||||
|
|
||||||
|
class CompleteSentenceTransformer(tf.keras.layers.Layer):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.tokenizer = text.SentencepieceTokenizer(
|
||||||
|
model=tf.io.gfile.GFile(os.path.join(tmp_dir, "spiece.model"), "rb").read()
|
||||||
|
)
|
||||||
|
self.model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||||
|
|
||||||
|
def call(self, inputs, *args, **kwargs):
|
||||||
|
tokens = self.tokenizer.tokenize(inputs)
|
||||||
|
input_ids, attention_mask = text.pad_model_inputs(
|
||||||
|
tokens, max_seq_length=64, pad_value=self.model.config.pad_token_id
|
||||||
|
)
|
||||||
|
outputs = self.model.generate(input_ids=input_ids, attention_mask=attention_mask)
|
||||||
|
return self.tokenizer.detokenize(outputs)
|
||||||
|
|
||||||
|
complete_model = CompleteSentenceTransformer()
|
||||||
|
inputs = tf.keras.layers.Input(shape=(1,), dtype=tf.string, name="inputs")
|
||||||
|
outputs = complete_model(inputs)
|
||||||
|
keras_model = tf.keras.Model(inputs, outputs)
|
||||||
|
keras_model.save(tmp_dir)
|
||||||
|
|
||||||
def test_eos_token_id_int_and_list_top_k_top_sampling(self):
|
def test_eos_token_id_int_and_list_top_k_top_sampling(self):
|
||||||
# Has PT equivalent: this test relies on random sampling
|
# Has PT equivalent: this test relies on random sampling
|
||||||
generation_kwargs = {
|
generation_kwargs = {
|
||||||
|
|||||||
Reference in New Issue
Block a user