Generate: Export TF generate with a TF tokenizer (#22310)

* Export TF generate with a TF tokenizer

* remove unused lines
This commit is contained in:
Joao Gante
2023-03-22 15:00:20 +00:00
committed by GitHub
parent 5fd4e3c87c
commit 12febc20db
2 changed files with 68 additions and 55 deletions

View File

@@ -13,13 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest
import numpy as np
from huggingface_hub import hf_hub_download
from transformers import is_tf_available
from transformers.testing_utils import require_tf, slow
from transformers import is_tensorflow_text_available, is_tf_available
from transformers.testing_utils import require_tensorflow_text, require_tf, slow
from ..test_modeling_tf_common import floats_tensor
from .test_framework_agnostic import GenerationIntegrationTestsMixin
@@ -40,6 +42,9 @@ if is_tf_available():
tf_top_k_top_p_filtering,
)
if is_tensorflow_text_available():
import tensorflow_text as text
@require_tf
class UtilsFunctionsTest(unittest.TestCase):
@@ -239,6 +244,36 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_new_tokens)
tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)
@slow
@require_tensorflow_text
def test_generate_tf_function_export_with_tf_tokenizer(self):
# TF-only test: tf.saved_model export
with tempfile.TemporaryDirectory() as tmp_dir:
# file needed to load the TF tokenizer
hf_hub_download(repo_id="google/flan-t5-small", filename="spiece.model", local_dir=tmp_dir)
class CompleteSentenceTransformer(tf.keras.layers.Layer):
def __init__(self):
super().__init__()
self.tokenizer = text.SentencepieceTokenizer(
model=tf.io.gfile.GFile(os.path.join(tmp_dir, "spiece.model"), "rb").read()
)
self.model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
def call(self, inputs, *args, **kwargs):
tokens = self.tokenizer.tokenize(inputs)
input_ids, attention_mask = text.pad_model_inputs(
tokens, max_seq_length=64, pad_value=self.model.config.pad_token_id
)
outputs = self.model.generate(input_ids=input_ids, attention_mask=attention_mask)
return self.tokenizer.detokenize(outputs)
complete_model = CompleteSentenceTransformer()
inputs = tf.keras.layers.Input(shape=(1,), dtype=tf.string, name="inputs")
outputs = complete_model(inputs)
keras_model = tf.keras.Model(inputs, outputs)
keras_model.save(tmp_dir)
def test_eos_token_id_int_and_list_top_k_top_sampling(self):
# Has PT equivalent: this test relies on random sampling
generation_kwargs = {