Add test for proper TF input signatures (#24320)

* Add test for proper input signatures

* No more signature pruning

* Test the dummy inputs are valid too

* fine-tine -> fine-tune

* Fix indent in test_dataset_conversion
This commit is contained in:
Matt
2023-06-16 17:03:13 +01:00
committed by GitHub
parent bdfd57d1d1
commit 9138995025
6 changed files with 21 additions and 19 deletions

View File

@@ -1065,6 +1065,16 @@ class TFModelTesterMixin:
output_for_kw_input = model(**inputs_np)
self.assert_outputs_same(output_for_dict_input, output_for_kw_input)
def test_valid_input_signature_and_dummies(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
call_args = inspect.signature(model.call).parameters
for key in model.input_signature:
self.assertIn(key, call_args)
for key in model.dummy_inputs:
self.assertIn(key, call_args)
def test_resize_token_embeddings(self):
# TODO (joao): after the embeddings refactor is complete, rework this test so as to rely exclusively on
# tf.keras.layers.Embedding
@@ -1700,7 +1710,7 @@ class TFModelTesterMixin:
for tensor in test_batch.values():
self.assertTrue(isinstance(tensor, tf.Tensor))
self.assertEqual(len(tensor), len(input_dataset)) # Assert we didn't lose any data
model(test_batch, training=False)
model(test_batch, training=False)
if "labels" in inspect.signature(model_class.call).parameters.keys():
tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class, return_labels=True)