More TF int dtype fixes (#20384)
* Add a test to ensure int dummy inputs are int64 * Move the test into the existing int64 test and update a lot of existing dummies * Fix remaining dummies * Fix remaining dummies * Test for int64 serving sigs as well * Update core tests to use tf.int64 * Add better messages to the assertions * Update all serving sigs to int64 * More sneaky hiding tf.int32s * Add an optional int32 signature in save_pretrained * make fixup * Add Amy's suggestions * Switch all serving sigs back to tf.int32 * Switch all dummies to tf.int32 * Adjust tests to check for tf.int32 instead of tf.int64 * Fix base dummy_inputs dtype * Start casting to tf.int32 in input_processing * Change dtype for unpack_inputs test * Add proper tf.int32 test * Make the alternate serving signature int64
This commit is contained in:
@@ -821,7 +821,7 @@ class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel):
|
||||
Returns:
|
||||
`Dict[str, tf.Tensor]`: The dummy inputs.
|
||||
"""
|
||||
dummy = {"input_ids": tf.constant(DUMMY_INPUTS)}
|
||||
dummy = {"input_ids": tf.constant(DUMMY_INPUTS, dtype=tf.int64)}
|
||||
# Add `encoder_hidden_states` to make the cross-attention layers' weights initialized
|
||||
if self.config.add_cross_attention:
|
||||
batch_size, seq_len = tf.constant(DUMMY_INPUTS).shape
|
||||
@@ -1365,7 +1365,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
|
||||
Returns:
|
||||
tf.Tensor with dummy inputs
|
||||
"""
|
||||
return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}
|
||||
return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS, dtype=tf.int64)}
|
||||
|
||||
@unpack_inputs
|
||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
|
||||
|
||||
Reference in New Issue
Block a user