Fix TF nightly tests (#20507)

* Fixed test_saved_model_extended

* Fix TFGPT2 tests

* make fixup

* Make sure keras-nlp utils are available for type hinting too

* Update src/transformers/testing_utils.py

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>

* make fixup

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
Matt
2022-11-30 14:47:54 +00:00
committed by GitHub
parent 761b3fad92
commit afad0c18d9
4 changed files with 22 additions and 6 deletions

View File

@@ -2,12 +2,12 @@ import unittest
from pathlib import Path
from tempfile import TemporaryDirectory
from transformers import AutoConfig, TFGPT2LMHeadModel, is_tensorflow_text_available, is_tf_available
from transformers import AutoConfig, TFGPT2LMHeadModel, is_keras_nlp_available, is_tf_available
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from transformers.testing_utils import require_tensorflow_text, slow
from transformers.testing_utils import require_keras_nlp, slow
if is_tensorflow_text_available():
if is_keras_nlp_available():
from transformers.models.gpt2 import TFGPT2Tokenizer
if is_tf_available():
@@ -40,7 +40,7 @@ if is_tf_available():
return outputs
@require_tensorflow_text
@require_keras_nlp
class GPTTokenizationTest(unittest.TestCase):
# The TF tokenizers are usually going to be used as pretrained tokenizers from existing model checkpoints,
# so that's what we focus on here.

View File

@@ -218,11 +218,17 @@ class TFCoreModelTesterMixin:
model = model_class(config)
num_out = len(model(class_inputs_dict))
for key in class_inputs_dict.keys():
for key in list(class_inputs_dict.keys()):
# Remove keys not in the serving signature, as the SavedModel will not be compiled to deal with them
if key not in model.serving.input_signature[0]:
del class_inputs_dict[key]
# Check it's a tensor, in case the inputs dict has some bools in it too
if isinstance(class_inputs_dict[key], tf.Tensor) and class_inputs_dict[key].dtype.is_integer:
elif isinstance(class_inputs_dict[key], tf.Tensor) and class_inputs_dict[key].dtype.is_integer:
class_inputs_dict[key] = tf.cast(class_inputs_dict[key], tf.int32)
if set(class_inputs_dict.keys()) != set(model.serving.input_signature[0].keys()):
continue # Some models have inputs that the preparation functions don't create, we skip those
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, saved_model=True)
saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")