Fix TF nightly tests (#20507)
* Fixed test_saved_model_extended * Fix TFGPT2 tests * make fixup * Make sure keras-nlp utils are available for type hinting too * Update src/transformers/testing_utils.py Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com> * make fixup Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -218,11 +218,17 @@ class TFCoreModelTesterMixin:
|
||||
model = model_class(config)
|
||||
num_out = len(model(class_inputs_dict))
|
||||
|
||||
for key in class_inputs_dict.keys():
|
||||
for key in list(class_inputs_dict.keys()):
|
||||
# Remove keys not in the serving signature, as the SavedModel will not be compiled to deal with them
|
||||
if key not in model.serving.input_signature[0]:
|
||||
del class_inputs_dict[key]
|
||||
# Check it's a tensor, in case the inputs dict has some bools in it too
|
||||
if isinstance(class_inputs_dict[key], tf.Tensor) and class_inputs_dict[key].dtype.is_integer:
|
||||
elif isinstance(class_inputs_dict[key], tf.Tensor) and class_inputs_dict[key].dtype.is_integer:
|
||||
class_inputs_dict[key] = tf.cast(class_inputs_dict[key], tf.int32)
|
||||
|
||||
if set(class_inputs_dict.keys()) != set(model.serving.input_signature[0].keys()):
|
||||
continue # Some models have inputs that the preparation functions don't create, we skip those
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname, saved_model=True)
|
||||
saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
|
||||
|
||||
Reference in New Issue
Block a user