Add test for proper TF input signatures (#24320)
* Add test for proper input signatures * No more signature pruning * Test the dummy inputs are valid too * fine-tine -> fine-tune * Fix indent in test_dataset_conversion
This commit is contained in:
@@ -1065,6 +1065,16 @@ class TFModelTesterMixin:
|
||||
output_for_kw_input = model(**inputs_np)
|
||||
self.assert_outputs_same(output_for_dict_input, output_for_kw_input)
|
||||
|
||||
def test_valid_input_signature_and_dummies(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
call_args = inspect.signature(model.call).parameters
|
||||
for key in model.input_signature:
|
||||
self.assertIn(key, call_args)
|
||||
for key in model.dummy_inputs:
|
||||
self.assertIn(key, call_args)
|
||||
|
||||
def test_resize_token_embeddings(self):
|
||||
# TODO (joao): after the embeddings refactor is complete, rework this test so as to rely exclusively on
|
||||
# tf.keras.layers.Embedding
|
||||
@@ -1700,7 +1710,7 @@ class TFModelTesterMixin:
|
||||
for tensor in test_batch.values():
|
||||
self.assertTrue(isinstance(tensor, tf.Tensor))
|
||||
self.assertEqual(len(tensor), len(input_dataset)) # Assert we didn't lose any data
|
||||
model(test_batch, training=False)
|
||||
model(test_batch, training=False)
|
||||
|
||||
if "labels" in inspect.signature(model_class.call).parameters.keys():
|
||||
tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
|
||||
@@ -217,18 +217,17 @@ class TFCoreModelTesterMixin:
|
||||
for model_class in self.all_model_classes:
|
||||
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
model = model_class(config)
|
||||
class_sig = model._prune_signature(model.input_signature)
|
||||
num_out = len(model(class_inputs_dict))
|
||||
|
||||
for key in list(class_inputs_dict.keys()):
|
||||
# Remove keys not in the serving signature, as the SavedModel will not be compiled to deal with them
|
||||
if key not in class_sig:
|
||||
if key not in model.input_signature:
|
||||
del class_inputs_dict[key]
|
||||
# Check it's a tensor, in case the inputs dict has some bools in it too
|
||||
elif isinstance(class_inputs_dict[key], tf.Tensor) and class_inputs_dict[key].dtype.is_integer:
|
||||
class_inputs_dict[key] = tf.cast(class_inputs_dict[key], tf.int32)
|
||||
|
||||
if set(class_inputs_dict.keys()) != set(class_sig.keys()):
|
||||
if set(class_inputs_dict.keys()) != set(model.input_signature.keys()):
|
||||
continue # Some models have inputs that the preparation functions don't create, we skip those
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
|
||||
Reference in New Issue
Block a user