More TF int dtype fixes (#20384)
* Add a test to ensure int dummy inputs are int64 * Move the test into the existing int64 test and update a lot of existing dummies * Fix remaining dummies * Fix remaining dummies * Test for int64 serving sigs as well * Update core tests to use tf.int64 * Add better messages to the assertions * Update all serving sigs to int64 * More sneaky hiding tf.int32s * Add an optional int32 signature in save_pretrained * make fixup * Add Amy's suggestions * Switch all serving sigs back to tf.int32 * Switch all dummies to tf.int32 * Adjust tests to check for tf.int32 instead of tf.int64 * Fix base dummy_inputs dtype * Start casting to tf.int32 in input_processing * Change dtype for unpack_inputs test * Add proper tf.int32 test * Make the alternate serving signature int64
This commit is contained in:
@@ -1643,7 +1643,7 @@ class TFModelTesterMixin:
|
||||
if metrics:
|
||||
self.assertTrue(len(accuracy1) == len(accuracy3) > 0, "Missing metrics!")
|
||||
|
||||
def test_int64_inputs(self):
|
||||
def test_int_support(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
prepared_for_class = self._prepare_for_class(
|
||||
@@ -1662,6 +1662,26 @@ class TFModelTesterMixin:
|
||||
}
|
||||
model = model_class(config)
|
||||
model(**prepared_for_class) # No assertion, we're just checking this doesn't throw an error
|
||||
int32_prepared_for_class = {
|
||||
key: tf.cast(tensor, tf.int32) if isinstance(tensor, tf.Tensor) and tensor.dtype.is_integer else tensor
|
||||
for key, tensor in prepared_for_class.items()
|
||||
}
|
||||
model(**int32_prepared_for_class) # No assertion, we're just checking this doesn't throw an error
|
||||
|
||||
# After testing that the model accepts all int inputs, confirm that its dummies are int32
|
||||
for key, tensor in model.dummy_inputs.items():
|
||||
self.assertTrue(isinstance(tensor, tf.Tensor), "Dummy inputs should be tf.Tensor!")
|
||||
if tensor.dtype.is_integer:
|
||||
self.assertTrue(tensor.dtype == tf.int32, "Integer dummy inputs should be tf.int32!")
|
||||
|
||||
# Also confirm that the serving sig uses int32
|
||||
if hasattr(model, "serving"):
|
||||
serving_sig = model.serving.input_signature
|
||||
for key, tensor_spec in serving_sig[0].items():
|
||||
if tensor_spec.dtype.is_integer:
|
||||
self.assertTrue(
|
||||
tensor_spec.dtype == tf.int32, "Serving signatures should use tf.int32 for ints!"
|
||||
)
|
||||
|
||||
def test_generate_with_headmasking(self):
|
||||
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
||||
@@ -2005,9 +2025,9 @@ class UtilsFunctionsTest(unittest.TestCase):
|
||||
return pixel_values, output_attentions, output_hidden_states, return_dict
|
||||
|
||||
dummy_model = DummyModel()
|
||||
input_ids = tf.constant([0, 1, 2, 3], dtype=tf.int64)
|
||||
past_key_values = tf.constant([4, 5, 6, 7], dtype=tf.int64)
|
||||
pixel_values = tf.constant([8, 9, 10, 11], dtype=tf.int64)
|
||||
input_ids = tf.constant([0, 1, 2, 3], dtype=tf.int32)
|
||||
past_key_values = tf.constant([4, 5, 6, 7], dtype=tf.int32)
|
||||
pixel_values = tf.constant([8, 9, 10, 11], dtype=tf.int32)
|
||||
|
||||
# test case 1: Pass inputs as keyword arguments; Booleans are inherited from the config.
|
||||
output = dummy_model.call(input_ids=input_ids, past_key_values=past_key_values)
|
||||
|
||||
Reference in New Issue
Block a user