Stop storing references to bound methods via tf.function (#24146)

* Stop storing references to bound methods in tf.functions

* Remove the gc.collect calls now that we resolved the underlying problem

* Remove the default signature from model.serving entirely, big cleanup

* Remove _prune_signature as self.input_signature can prune itself

* Restore serving docstring

* Update int support test to check the input signature

* Make sure other tests also use model.input_signature and not serving.input_signature

* Restore _prune_signature

* Remove the doctest GC now it's no longer needed

* Correct core tests to use the pruned sig

* order lines correctly in core tests

* Add eager_serving back with a deprecation warning
This commit is contained in:
Matt
2023-06-13 19:04:22 +01:00
committed by GitHub
parent b979a2064d
commit 3bd1fe4315
8 changed files with 34 additions and 48 deletions

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import gc
import json
import os
import shutil
@@ -551,11 +550,6 @@ class TFRagDPRBartTest(TFRagTestMixin, unittest.TestCase):
@require_sentencepiece
@require_tokenizers
class TFRagModelIntegrationTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
@cached_property
def token_model(self):
return TFRagTokenForGeneration.from_pretrained_question_encoder_generator(

View File

@@ -17,7 +17,6 @@
from __future__ import annotations
import gc
import inspect
import unittest
@@ -431,11 +430,6 @@ def prepare_dog_img():
@require_tf
@slow
class TFSamModelIntegrationTest(unittest.TestCase):
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
def test_inference_mask_generation_no_point(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

View File

@@ -15,7 +15,6 @@
from __future__ import annotations
import gc
import unittest
from transformers import XGLMConfig, XGLMTokenizer, is_tf_available
@@ -173,11 +172,6 @@ class TFXGLMModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase
@require_tf
class TFXGLMModelLanguageGenerationTest(unittest.TestCase):
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
@slow
def test_lm_generate_xglm(self, verify_outputs=True):
model = TFXGLMForCausalLM.from_pretrained("facebook/xglm-564M")

View File

@@ -1687,14 +1687,10 @@ class TFModelTesterMixin:
if tensor.dtype.is_integer:
self.assertTrue(tensor.dtype == tf.int32, "Integer dummy inputs should be tf.int32!")
# Also confirm that the serving sig uses int32
if hasattr(model, "serving"):
serving_sig = model.serving.input_signature
for key, tensor_spec in serving_sig[0].items():
if tensor_spec.dtype.is_integer:
self.assertTrue(
tensor_spec.dtype == tf.int32, "Serving signatures should use tf.int32 for ints!"
)
# Also confirm that the input_signature uses int32
for key, tensor_spec in model.input_signature.items():
if tensor_spec.dtype.is_integer:
self.assertTrue(tensor_spec.dtype == tf.int32, "Input signatures should use tf.int32 for ints!")
def test_generate_with_headmasking(self):
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]

View File

@@ -217,17 +217,18 @@ class TFCoreModelTesterMixin:
for model_class in self.all_model_classes:
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
class_sig = model._prune_signature(model.input_signature)
num_out = len(model(class_inputs_dict))
for key in list(class_inputs_dict.keys()):
# Remove keys not in the serving signature, as the SavedModel will not be compiled to deal with them
if key not in model.serving.input_signature[0]:
if key not in class_sig:
del class_inputs_dict[key]
# Check it's a tensor, in case the inputs dict has some bools in it too
elif isinstance(class_inputs_dict[key], tf.Tensor) and class_inputs_dict[key].dtype.is_integer:
class_inputs_dict[key] = tf.cast(class_inputs_dict[key], tf.int32)
if set(class_inputs_dict.keys()) != set(model.serving.input_signature[0].keys()):
if set(class_inputs_dict.keys()) != set(class_sig.keys()):
continue # Some models have inputs that the preparation functions don't create, we skip those
with tempfile.TemporaryDirectory() as tmpdirname: