More TF int dtype fixes (#20384)
* Add a test to ensure int dummy inputs are int64 * Move the test into the existing int64 test and update a lot of existing dummies * Fix remaining dummies * Fix remaining dummies * Test for int64 serving sigs as well * Update core tests to use tf.int64 * Add better messages to the assertions * Update all serving sigs to int64 * More sneaky hiding tf.int32s * Add an optional int32 signature in save_pretrained * make fixup * Add Amy's suggestions * Switch all serving sigs back to tf.int32 * Switch all dummies to tf.int32 * Adjust tests to check for tf.int32 instead of tf.int64 * Fix base dummy_inputs dtype * Start casting to tf.int32 in input_processing * Change dtype for unpack_inputs test * Add proper tf.int32 test * Make the alternate serving signature int64
This commit is contained in:
@@ -218,6 +218,11 @@ class TFCoreModelTesterMixin:
|
||||
model = model_class(config)
|
||||
num_out = len(model(class_inputs_dict))
|
||||
|
||||
for key in class_inputs_dict.keys():
|
||||
# Check it's a tensor, in case the inputs dict has some bools in it too
|
||||
if isinstance(class_inputs_dict[key], tf.Tensor) and class_inputs_dict[key].dtype.is_integer:
|
||||
class_inputs_dict[key] = tf.cast(class_inputs_dict[key], tf.int32)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname, saved_model=True)
|
||||
saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
|
||||
|
||||
Reference in New Issue
Block a user