Generate: TF supports multiple eos tokens (#21571)
This commit is contained in:
@@ -19,6 +19,7 @@ import unittest
|
||||
from transformers import is_tf_available
|
||||
from transformers.testing_utils import require_tf, slow
|
||||
|
||||
from ..test_modeling_tf_common import floats_tensor
|
||||
from .test_framework_agnostic import GenerationIntegrationTestsMixin
|
||||
|
||||
|
||||
@@ -26,8 +27,11 @@ if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
TFAutoModelForCausalLM,
|
||||
TFAutoModelForSeq2SeqLM,
|
||||
TFAutoModelForSpeechSeq2Seq,
|
||||
TFAutoModelForVision2Seq,
|
||||
TFLogitsProcessorList,
|
||||
TFMinLengthLogitsProcessor,
|
||||
tf_top_k_top_p_filtering,
|
||||
@@ -136,15 +140,19 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
|
||||
# Lookup table from framework-neutral names (the string keys) to the TensorFlow
# objects that implement them. It is only built when TensorFlow is importable,
# because the values reference `tf` and the TF* classes directly.
# NOTE(review): presumably read by the tests inherited from
# GenerationIntegrationTestsMixin -- confirm against test_framework_agnostic.
if is_tf_available():
|
||||
framework_dependent_parameters = {
|
||||
"AutoModelForCausalLM": TFAutoModelForCausalLM,
|
||||
"AutoModelForSpeechSeq2Seq": TFAutoModelForSpeechSeq2Seq,
|
||||
"AutoModelForSeq2SeqLM": TFAutoModelForSeq2SeqLM,
|
||||
"AutoModelForVision2Seq": TFAutoModelForVision2Seq,
|
||||
"LogitsProcessorList": TFLogitsProcessorList,
|
||||
"MinLengthLogitsProcessor": TFMinLengthLogitsProcessor,
|
||||
"create_tensor_fn": tf.convert_to_tensor,
|
||||
"floats_tensor": floats_tensor,
|
||||
"return_tensors": "tf",
|
||||
}
|
||||
|
||||
@slow
|
||||
def test_generate_tf_function_export_fixed_input_length(self):
|
||||
# TF-only test: tf.saved_model export
|
||||
test_model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
input_length = 2
|
||||
max_new_tokens = 2
|
||||
@@ -187,6 +195,7 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
|
||||
|
||||
@slow
|
||||
def test_generate_tf_function_export_fixed_batch_size(self):
|
||||
# TF-only test: tf.saved_model export
|
||||
test_model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
batch_size = 1
|
||||
max_new_tokens = 2
|
||||
@@ -226,3 +235,32 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
|
||||
tf_func_outputs = serving_func(**inputs)["sequences"]
|
||||
tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_new_tokens)
|
||||
tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)
|
||||
|
||||
def test_eos_token_id_int_and_list_top_k_top_sampling(self):
    """Sampled generation accepts `eos_token_id` as a single int or as a list of ints.

    The same prompt is generated twice with an identical seed and identical
    sampling settings, once with `eos_token_id=638` and once with
    `eos_token_id=[638, 198]`. In both cases the first returned sequence is
    expected to contain exactly `expectation` tokens.
    """
    # Has PT equivalent: this test relies on random sampling
    generation_kwargs = {
        "do_sample": True,
        "num_beams": 1,
        "top_p": 0.7,
        "top_k": 10,
        "temperature": 0.7,
    }
    expectation = 14

    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    text = """Hello, my dog is cute and"""
    tokens = tokenizer(text, return_tensors="tf")
    model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

    # Run the int form first, then the list form; subTest reports each case on
    # its own, so a failure in one does not hide the result of the other.
    for eos_token_id in (638, [638, 198]):
        with self.subTest(eos_token_id=eos_token_id):
            # forces the generation to happen on CPU, to avoid GPU-related quirks
            with tf.device(":/CPU:0"):
                # Re-seed before every call so both cases draw the same samples.
                tf.random.set_seed(0)
                generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
            # assertEqual (rather than assertTrue on ==) prints both lengths on failure.
            self.assertEqual(expectation, len(generated_tokens[0]))
|
||||
|
||||
Reference in New Issue
Block a user