change shape to support dynamic batch input in tf.function XLA generate for tf serving (#18372)
* change shape to support dynamic batch input in tf.generate * add tests Co-authored-by: nlpcatcode <nlpcodecat@gmail.com>
This commit is contained in:
@@ -73,6 +73,7 @@ if is_tf_available():
|
||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
BertConfig,
|
||||
TFAutoModel,
|
||||
TFAutoModelForSeq2SeqLM,
|
||||
TFAutoModelForSequenceClassification,
|
||||
TFBertModel,
|
||||
TFSharedEmbeddings,
|
||||
@@ -2163,6 +2164,46 @@ class UtilsFunctionsTest(unittest.TestCase):
|
||||
for p1, p2 in zip(model.weights, new_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
def test_generate_tf_function_export(self):
|
||||
test_model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
max_length = 8
|
||||
|
||||
class DummyModel(tf.Module):
|
||||
def __init__(self, model):
|
||||
super(DummyModel, self).__init__()
|
||||
self.model = model
|
||||
|
||||
@tf.function(
|
||||
input_signature=(
|
||||
tf.TensorSpec((None, max_length), tf.int32, name="input_ids"),
|
||||
tf.TensorSpec((None, max_length), tf.int32, name="attention_mask"),
|
||||
),
|
||||
jit_compile=True,
|
||||
)
|
||||
def serving(self, input_ids, attention_mask):
|
||||
outputs = self.model.generate(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_new_tokens=max_length,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
return {"sequences": outputs["sequences"]}
|
||||
|
||||
dummy_input_ids = [[2, 3, 4, 1, 0, 0, 0, 0], [102, 103, 104, 105, 1, 0, 0, 0]]
|
||||
dummy_attention_masks = [[1, 1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 1, 0, 0, 0]]
|
||||
dummy_model = DummyModel(model=test_model)
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tf.saved_model.save(dummy_model, tmp_dir, signatures={"serving_default": dummy_model.serving})
|
||||
serving_func = tf.saved_model.load(tmp_dir).signatures["serving_default"]
|
||||
for batch_size in range(1, len(dummy_input_ids) + 1):
|
||||
inputs = {
|
||||
"input_ids": tf.constant(dummy_input_ids[:batch_size]),
|
||||
"attention_mask": tf.constant(dummy_attention_masks[:batch_size]),
|
||||
}
|
||||
tf_func_outputs = serving_func(**inputs)["sequences"]
|
||||
tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_length)
|
||||
tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)
|
||||
|
||||
|
||||
@require_tf
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user