Stop storing references to bound methods via tf.function (#24146)
* Stop storing references to bound methods in tf.functions * Remove the gc.collect calls now that we resolved the underlying problem * Remove the default signature from model.serving entirely, big cleanup * Remove _prune_signature as self.input_signature can prune itself * Restore serving docstring * Update int support test to check the input signature * Make sure other tests also use model.input_signature and not serving.input_signature * Restore _prune_signature * Remove the doctest GC now it's no longer needed * Correct core tests to use the pruned sig * order lines correctly in core tests * Add eager_serving back with a deprecation warning
This commit is contained in:
@@ -1171,12 +1171,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.name_or_path = config.name_or_path
|
self.name_or_path = config.name_or_path
|
||||||
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
|
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
|
||||||
if not hasattr(self, "serving"): # Don't overwrite existing serving signatures
|
|
||||||
self.serving = tf.function(
|
|
||||||
self.eager_serving, input_signature=[self._prune_signature(self.input_signature)]
|
|
||||||
)
|
|
||||||
# Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec
|
# Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec
|
||||||
self._set_save_spec(self.serving.input_signature[0])
|
self._set_save_spec(self._prune_signature(self.input_signature))
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
return self.config.to_dict()
|
return self.config.to_dict()
|
||||||
@@ -1226,15 +1222,31 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
head_mask = tf.cast(head_mask, tf.float32) # switch to float if need + fp16 compatibility
|
head_mask = tf.cast(head_mask, tf.float32) # switch to float if need + fp16 compatibility
|
||||||
return head_mask
|
return head_mask
|
||||||
|
|
||||||
|
@tf.function
|
||||||
|
def serving(self, inputs):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
Method used for serving the model. Does not have a specific signature, but will be specialized as concrete
|
||||||
|
functions when saving with `save_pretrained`.
|
||||||
|
inputs (`Dict[str, tf.Tensor]`):
|
||||||
|
The input of the saved model as a dictionary of tensors.
|
||||||
|
"""
|
||||||
|
output = self.call(inputs)
|
||||||
|
|
||||||
|
return self.serving_output(output)
|
||||||
|
|
||||||
def eager_serving(self, inputs):
|
def eager_serving(self, inputs):
|
||||||
"""
|
"""
|
||||||
Method used for serving the model. Intended not to be compiled with a tf.function decorator so that we can use
|
Method used for serving the model. This method is deprecated, and will be removed.
|
||||||
it to generate multiple signatures later.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
inputs (`Dict[str, tf.Tensor]`):
|
inputs (`Dict[str, tf.Tensor]`):
|
||||||
The input of the saved model as a dictionary of tensors.
|
The input of the saved model as a dictionary of tensors.
|
||||||
"""
|
"""
|
||||||
|
warnings.warn(
|
||||||
|
"The function `eager_serving` is deprecated and will be removed in version 4.32.0 of Transformers",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
output = self.call(inputs)
|
output = self.call(inputs)
|
||||||
|
|
||||||
return self.serving_output(output)
|
return self.serving_output(output)
|
||||||
@@ -2409,17 +2421,19 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
if getattr(self.config, "torch_dtype", None) is not None and not isinstance(self.config.torch_dtype, str):
|
if getattr(self.config, "torch_dtype", None) is not None and not isinstance(self.config.torch_dtype, str):
|
||||||
self.config.torch_dtype = str(self.config.torch_dtype).split(".")[1]
|
self.config.torch_dtype = str(self.config.torch_dtype).split(".")[1]
|
||||||
if signatures is None:
|
if signatures is None:
|
||||||
if any(spec.dtype == tf.int32 for spec in self.serving.input_signature[0].values()):
|
sig = self._prune_signature(self.input_signature)
|
||||||
|
serving_default = self.serving.get_concrete_function(sig)
|
||||||
|
if any(spec.dtype == tf.int32 for spec in sig.values()):
|
||||||
int64_spec = {
|
int64_spec = {
|
||||||
key: tf.TensorSpec(
|
key: tf.TensorSpec(
|
||||||
shape=spec.shape, dtype=tf.int64 if spec.dtype == tf.int32 else spec.dtype, name=spec.name
|
shape=spec.shape, dtype=tf.int64 if spec.dtype == tf.int32 else spec.dtype, name=spec.name
|
||||||
)
|
)
|
||||||
for key, spec in self.serving.input_signature[0].items()
|
for key, spec in sig.items()
|
||||||
}
|
}
|
||||||
int64_serving = tf.function(self.eager_serving, input_signature=[int64_spec])
|
int64_serving = self.serving.get_concrete_function(int64_spec)
|
||||||
signatures = {"serving_default": self.serving, "int64_serving": int64_serving}
|
signatures = {"serving_default": serving_default, "int64_serving": int64_serving}
|
||||||
else:
|
else:
|
||||||
signatures = self.serving
|
signatures = serving_default
|
||||||
saved_model_dir = os.path.join(save_directory, "saved_model", str(version))
|
saved_model_dir = os.path.join(save_directory, "saved_model", str(version))
|
||||||
self.save(saved_model_dir, include_optimizer=False, signatures=signatures)
|
self.save(saved_model_dir, include_optimizer=False, signatures=signatures)
|
||||||
logger.info(f"Saved model created in {saved_model_dir}")
|
logger.info(f"Saved model created in {saved_model_dir}")
|
||||||
|
|||||||
@@ -1882,13 +1882,6 @@ def preprocess_string(string, skip_cuda_tests):
|
|||||||
if not is_cuda_found:
|
if not is_cuda_found:
|
||||||
modified_string = "".join(codeblocks)
|
modified_string = "".join(codeblocks)
|
||||||
|
|
||||||
if ">>>" in modified_string:
|
|
||||||
lines = modified_string.split("\n")
|
|
||||||
indent = len(lines[-1]) - len(lines[-1].lstrip())
|
|
||||||
|
|
||||||
cleanup = ">>> import gc; gc.collect() # doctest: +IGNORE_RESULT"
|
|
||||||
modified_string += "\n" + " " * indent + cleanup
|
|
||||||
|
|
||||||
return modified_string
|
return modified_string
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2676,7 +2676,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
|
|||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
self.model = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="model")
|
self.model = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="model")
|
||||||
self.model._set_save_spec(inputs=self.serving.input_signature)
|
self.model._set_save_spec(self._prune_signature(self.input_signature))
|
||||||
self.use_cache = config.use_cache
|
self.use_cache = config.use_cache
|
||||||
# final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
|
# final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
|
||||||
self.bias_layer = BiasLayer(
|
self.bias_layer = BiasLayer(
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import gc
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
@@ -551,11 +550,6 @@ class TFRagDPRBartTest(TFRagTestMixin, unittest.TestCase):
|
|||||||
@require_sentencepiece
|
@require_sentencepiece
|
||||||
@require_tokenizers
|
@require_tokenizers
|
||||||
class TFRagModelIntegrationTests(unittest.TestCase):
|
class TFRagModelIntegrationTests(unittest.TestCase):
|
||||||
def tearDown(self):
|
|
||||||
super().tearDown()
|
|
||||||
# clean-up as much as possible GPU memory occupied by PyTorch
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def token_model(self):
|
def token_model(self):
|
||||||
return TFRagTokenForGeneration.from_pretrained_question_encoder_generator(
|
return TFRagTokenForGeneration.from_pretrained_question_encoder_generator(
|
||||||
|
|||||||
@@ -17,7 +17,6 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import gc
|
|
||||||
import inspect
|
import inspect
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
@@ -431,11 +430,6 @@ def prepare_dog_img():
|
|||||||
@require_tf
|
@require_tf
|
||||||
@slow
|
@slow
|
||||||
class TFSamModelIntegrationTest(unittest.TestCase):
|
class TFSamModelIntegrationTest(unittest.TestCase):
|
||||||
def tearDown(self):
|
|
||||||
super().tearDown()
|
|
||||||
# clean-up as much as possible GPU memory occupied by PyTorch
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
def test_inference_mask_generation_no_point(self):
|
def test_inference_mask_generation_no_point(self):
|
||||||
model = TFSamModel.from_pretrained("facebook/sam-vit-base")
|
model = TFSamModel.from_pretrained("facebook/sam-vit-base")
|
||||||
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
|
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
|
||||||
|
|||||||
@@ -15,7 +15,6 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import gc
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import XGLMConfig, XGLMTokenizer, is_tf_available
|
from transformers import XGLMConfig, XGLMTokenizer, is_tf_available
|
||||||
@@ -173,11 +172,6 @@ class TFXGLMModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
|||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFXGLMModelLanguageGenerationTest(unittest.TestCase):
|
class TFXGLMModelLanguageGenerationTest(unittest.TestCase):
|
||||||
def tearDown(self):
|
|
||||||
super().tearDown()
|
|
||||||
# clean-up as much as possible GPU memory occupied by PyTorch
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_lm_generate_xglm(self, verify_outputs=True):
|
def test_lm_generate_xglm(self, verify_outputs=True):
|
||||||
model = TFXGLMForCausalLM.from_pretrained("facebook/xglm-564M")
|
model = TFXGLMForCausalLM.from_pretrained("facebook/xglm-564M")
|
||||||
|
|||||||
@@ -1687,14 +1687,10 @@ class TFModelTesterMixin:
|
|||||||
if tensor.dtype.is_integer:
|
if tensor.dtype.is_integer:
|
||||||
self.assertTrue(tensor.dtype == tf.int32, "Integer dummy inputs should be tf.int32!")
|
self.assertTrue(tensor.dtype == tf.int32, "Integer dummy inputs should be tf.int32!")
|
||||||
|
|
||||||
# Also confirm that the serving sig uses int32
|
# Also confirm that the input_signature uses int32
|
||||||
if hasattr(model, "serving"):
|
for key, tensor_spec in model.input_signature.items():
|
||||||
serving_sig = model.serving.input_signature
|
if tensor_spec.dtype.is_integer:
|
||||||
for key, tensor_spec in serving_sig[0].items():
|
self.assertTrue(tensor_spec.dtype == tf.int32, "Input signatures should use tf.int32 for ints!")
|
||||||
if tensor_spec.dtype.is_integer:
|
|
||||||
self.assertTrue(
|
|
||||||
tensor_spec.dtype == tf.int32, "Serving signatures should use tf.int32 for ints!"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_generate_with_headmasking(self):
|
def test_generate_with_headmasking(self):
|
||||||
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
||||||
|
|||||||
@@ -217,17 +217,18 @@ class TFCoreModelTesterMixin:
|
|||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
|
class_sig = model._prune_signature(model.input_signature)
|
||||||
num_out = len(model(class_inputs_dict))
|
num_out = len(model(class_inputs_dict))
|
||||||
|
|
||||||
for key in list(class_inputs_dict.keys()):
|
for key in list(class_inputs_dict.keys()):
|
||||||
# Remove keys not in the serving signature, as the SavedModel will not be compiled to deal with them
|
# Remove keys not in the serving signature, as the SavedModel will not be compiled to deal with them
|
||||||
if key not in model.serving.input_signature[0]:
|
if key not in class_sig:
|
||||||
del class_inputs_dict[key]
|
del class_inputs_dict[key]
|
||||||
# Check it's a tensor, in case the inputs dict has some bools in it too
|
# Check it's a tensor, in case the inputs dict has some bools in it too
|
||||||
elif isinstance(class_inputs_dict[key], tf.Tensor) and class_inputs_dict[key].dtype.is_integer:
|
elif isinstance(class_inputs_dict[key], tf.Tensor) and class_inputs_dict[key].dtype.is_integer:
|
||||||
class_inputs_dict[key] = tf.cast(class_inputs_dict[key], tf.int32)
|
class_inputs_dict[key] = tf.cast(class_inputs_dict[key], tf.int32)
|
||||||
|
|
||||||
if set(class_inputs_dict.keys()) != set(model.serving.input_signature[0].keys()):
|
if set(class_inputs_dict.keys()) != set(class_sig.keys()):
|
||||||
continue # Some models have inputs that the preparation functions don't create, we skip those
|
continue # Some models have inputs that the preparation functions don't create, we skip those
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
|||||||
Reference in New Issue
Block a user