Generate: TF .generate() can now be exported with dynamic length (#21474)

This commit is contained in:
Joao Gante
2023-02-09 12:52:30 +00:00
committed by GitHub
parent e69f9715eb
commit 2edf9a857b
3 changed files with 68 additions and 20 deletions

View File

@@ -144,9 +144,10 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
}
@slow
def test_generate_tf_function_export(self):
def test_generate_tf_function_export_fixed_input_length(self):
test_model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
max_length = 2
input_length = 2
max_new_tokens = 2
class DummyModel(tf.Module):
def __init__(self, model):
@@ -155,8 +156,8 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
@tf.function(
input_signature=(
tf.TensorSpec((None, max_length), tf.int32, name="input_ids"),
tf.TensorSpec((None, max_length), tf.int32, name="attention_mask"),
tf.TensorSpec((None, input_length), tf.int32, name="input_ids"),
tf.TensorSpec((None, input_length), tf.int32, name="attention_mask"),
),
jit_compile=True,
)
@@ -164,7 +165,7 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
outputs = self.model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=max_length,
max_new_tokens=max_new_tokens,
return_dict_in_generate=True,
)
return {"sequences": outputs["sequences"]}
@@ -181,5 +182,47 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
"attention_mask": tf.constant(dummy_attention_masks[:batch_size]),
}
tf_func_outputs = serving_func(**inputs)["sequences"]
tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_length)
tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_new_tokens)
tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)
@slow
def test_generate_tf_function_export_fixed_batch_size(self):
test_model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
batch_size = 1
max_new_tokens = 2
class DummyModel(tf.Module):
def __init__(self, model):
super(DummyModel, self).__init__()
self.model = model
@tf.function(
input_signature=(
tf.TensorSpec((batch_size, None), tf.int32, name="input_ids"),
tf.TensorSpec((batch_size, None), tf.int32, name="attention_mask"),
),
jit_compile=True,
)
def serving(self, input_ids, attention_mask):
outputs = self.model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=max_new_tokens,
return_dict_in_generate=True,
)
return {"sequences": outputs["sequences"]}
dummy_input_ids = [[2], [102, 103]]
dummy_attention_masks = [[1], [1, 1]]
dummy_model = DummyModel(model=test_model)
with tempfile.TemporaryDirectory() as tmp_dir:
tf.saved_model.save(dummy_model, tmp_dir, signatures={"serving_default": dummy_model.serving})
serving_func = tf.saved_model.load(tmp_dir).signatures["serving_default"]
for input_row in range(len(dummy_input_ids)):
inputs = {
"input_ids": tf.constant([dummy_input_ids[input_row]]),
"attention_mask": tf.constant([dummy_attention_masks[input_row]]),
}
tf_func_outputs = serving_func(**inputs)["sequences"]
tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_new_tokens)
tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)