Enabling custom TF signature draft (#19249)
* Custom TF signature draft * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Matt <Rocketknight1@users.noreply.github.com> Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Adding tf signature tests * Fixing signature check and adding asserts * fixing model load path * Adjusting signature tests * Formatting file Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Matt <Rocketknight1@users.noreply.github.com> Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Dimitre Oliveira <dimitreoliveira@Dimitres-MacBook-Air.local>
This commit is contained in:
@@ -2216,6 +2216,46 @@ class UtilsFunctionsTest(unittest.TestCase):
|
||||
for p1, p2 in zip(model.weights, new_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
def test_save_pretrained_signatures(self):
|
||||
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
|
||||
# Short custom TF signature function.
|
||||
# `input_signature` is specific to BERT.
|
||||
@tf.function(
|
||||
input_signature=[
|
||||
[
|
||||
tf.TensorSpec([None, None], tf.int32, name="input_ids"),
|
||||
tf.TensorSpec([None, None], tf.int32, name="token_type_ids"),
|
||||
tf.TensorSpec([None, None], tf.int32, name="attention_mask"),
|
||||
]
|
||||
]
|
||||
)
|
||||
def serving_fn(input):
|
||||
return model(input)
|
||||
|
||||
# Using default signature (default behavior) overrides 'serving_default'
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, saved_model=True, signatures=None)
|
||||
model_loaded = tf.keras.models.load_model(f"{tmp_dir}/saved_model/1")
|
||||
self.assertTrue("serving_default" in list(model_loaded.signatures.keys()))
|
||||
|
||||
# Providing custom signature function
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, saved_model=True, signatures={"custom_signature": serving_fn})
|
||||
model_loaded = tf.keras.models.load_model(f"{tmp_dir}/saved_model/1")
|
||||
self.assertTrue("custom_signature" in list(model_loaded.signatures.keys()))
|
||||
|
||||
# Providing multiple custom signature function
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(
|
||||
tmp_dir,
|
||||
saved_model=True,
|
||||
signatures={"custom_signature_1": serving_fn, "custom_signature_2": serving_fn},
|
||||
)
|
||||
model_loaded = tf.keras.models.load_model(f"{tmp_dir}/saved_model/1")
|
||||
self.assertTrue("custom_signature_1" in list(model_loaded.signatures.keys()))
|
||||
self.assertTrue("custom_signature_2" in list(model_loaded.signatures.keys()))
|
||||
|
||||
|
||||
@require_tf
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user