Add test for proper TF input signatures (#24320)

* Add test for proper input signatures

* No more signature pruning

* Test the dummy inputs are valid too

* fine-tine -> fine-tune

* Fix indent in test_dataset_conversion
This commit is contained in:
Matt
2023-06-16 17:03:13 +01:00
committed by GitHub
parent bdfd57d1d1
commit 9138995025
6 changed files with 21 additions and 19 deletions

View File

@@ -217,18 +217,17 @@ class TFCoreModelTesterMixin:
for model_class in self.all_model_classes:
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
class_sig = model._prune_signature(model.input_signature)
num_out = len(model(class_inputs_dict))
for key in list(class_inputs_dict.keys()):
# Remove keys not in the serving signature, as the SavedModel will not be compiled to deal with them
if key not in class_sig:
if key not in model.input_signature:
del class_inputs_dict[key]
# Check it's a tensor, in case the inputs dict has some bools in it too
elif isinstance(class_inputs_dict[key], tf.Tensor) and class_inputs_dict[key].dtype.is_integer:
class_inputs_dict[key] = tf.cast(class_inputs_dict[key], tf.int32)
if set(class_inputs_dict.keys()) != set(class_sig.keys()):
if set(class_inputs_dict.keys()) != set(model.input_signature.keys()):
continue # Some models have inputs that the preparation functions don't create, we skip those
with tempfile.TemporaryDirectory() as tmpdirname: