TF version compatibility fixes (#23663)
* New TF version compatibility fixes * Remove dummy print statement, move expand_1d * Make a proper framework inference function * Make a proper framework inference function * ValueError -> TypeError
This commit is contained in:
@@ -38,7 +38,7 @@ from .activations_tf import get_tf_activation
|
|||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
from .dynamic_module_utils import custom_object_save
|
from .dynamic_module_utils import custom_object_save
|
||||||
from .generation import GenerationConfig, TFGenerationMixin
|
from .generation import GenerationConfig, TFGenerationMixin
|
||||||
from .tf_utils import shape_list
|
from .tf_utils import expand_1d, load_attributes_from_hdf5_group, save_attributes_to_hdf5_group, shape_list
|
||||||
from .utils import (
|
from .utils import (
|
||||||
DUMMY_INPUTS,
|
DUMMY_INPUTS,
|
||||||
SAFE_WEIGHTS_INDEX_NAME,
|
SAFE_WEIGHTS_INDEX_NAME,
|
||||||
@@ -65,16 +65,15 @@ from .utils import (
|
|||||||
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
|
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
|
||||||
|
|
||||||
|
|
||||||
if parse(tf.__version__) >= parse("2.11.0"):
|
if parse(tf.__version__).minor >= 13:
|
||||||
|
from keras import backend as K
|
||||||
|
from keras.__internal__ import KerasTensor
|
||||||
|
elif parse(tf.__version__).minor >= 11:
|
||||||
from keras import backend as K
|
from keras import backend as K
|
||||||
from keras.engine import data_adapter
|
|
||||||
from keras.engine.keras_tensor import KerasTensor
|
from keras.engine.keras_tensor import KerasTensor
|
||||||
from keras.saving.legacy import hdf5_format
|
|
||||||
else:
|
else:
|
||||||
from tensorflow.python.keras import backend as K
|
from tensorflow.python.keras import backend as K
|
||||||
from tensorflow.python.keras.engine import data_adapter
|
|
||||||
from tensorflow.python.keras.engine.keras_tensor import KerasTensor
|
from tensorflow.python.keras.engine.keras_tensor import KerasTensor
|
||||||
from tensorflow.python.keras.saving import hdf5_format
|
|
||||||
|
|
||||||
|
|
||||||
if is_safetensors_available():
|
if is_safetensors_available():
|
||||||
@@ -797,9 +796,7 @@ def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatch
|
|||||||
try:
|
try:
|
||||||
with h5py.File(resolved_archive_file, "r") as sharded_checkpoint_file:
|
with h5py.File(resolved_archive_file, "r") as sharded_checkpoint_file:
|
||||||
# Retrieve the name of each layer from the H5 file
|
# Retrieve the name of each layer from the H5 file
|
||||||
saved_h5_model_layers_name = set(
|
saved_h5_model_layers_name = set(load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names"))
|
||||||
hdf5_format.load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names")
|
|
||||||
)
|
|
||||||
weight_value_tuples = []
|
weight_value_tuples = []
|
||||||
|
|
||||||
# Compute missing and unexpected sub layers
|
# Compute missing and unexpected sub layers
|
||||||
@@ -898,9 +895,7 @@ def load_tf_weights_from_h5(model, resolved_archive_file, ignore_mismatched_size
|
|||||||
# Read the H5 file
|
# Read the H5 file
|
||||||
with h5py.File(resolved_archive_file, "r") as sharded_checkpoint_file:
|
with h5py.File(resolved_archive_file, "r") as sharded_checkpoint_file:
|
||||||
# Retrieve the name of each layer from the H5 file
|
# Retrieve the name of each layer from the H5 file
|
||||||
saved_h5_model_layers_name = set(
|
saved_h5_model_layers_name = set(load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names"))
|
||||||
hdf5_format.load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names")
|
|
||||||
)
|
|
||||||
|
|
||||||
# Find the missing layers from the high level list of layers
|
# Find the missing layers from the high level list of layers
|
||||||
missing_layers = list({layer.name for layer in model.layers} - saved_h5_model_layers_name)
|
missing_layers = list({layer.name for layer in model.layers} - saved_h5_model_layers_name)
|
||||||
@@ -924,7 +919,7 @@ def load_tf_weights_from_h5(model, resolved_archive_file, ignore_mismatched_size
|
|||||||
|
|
||||||
# Create a dict from the H5 saved model that looks like {"weight_name": weight_value}
|
# Create a dict from the H5 saved model that looks like {"weight_name": weight_value}
|
||||||
# And a set with only the names
|
# And a set with only the names
|
||||||
for weight_name in hdf5_format.load_attributes_from_hdf5_group(h5_layer_object, "weight_names"):
|
for weight_name in load_attributes_from_hdf5_group(h5_layer_object, "weight_names"):
|
||||||
# TF names always start with the model name so we ignore it
|
# TF names always start with the model name so we ignore it
|
||||||
name = "/".join(weight_name.split("/")[1:])
|
name = "/".join(weight_name.split("/")[1:])
|
||||||
|
|
||||||
@@ -1528,8 +1523,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
output_to_label = {val: key for key, val in label_to_output.items()}
|
output_to_label = {val: key for key, val in label_to_output.items()}
|
||||||
if not self._using_dummy_loss and parse(tf.__version__) < parse("2.11.0"):
|
if not self._using_dummy_loss and parse(tf.__version__) < parse("2.11.0"):
|
||||||
# Newer TF train steps leave this out
|
# Newer TF train steps leave this out
|
||||||
data = data_adapter.expand_1d(data)
|
data = expand_1d(data)
|
||||||
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
|
x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
|
||||||
# If the inputs are mutable dictionaries, make a shallow copy of them because we will modify
|
# If the inputs are mutable dictionaries, make a shallow copy of them because we will modify
|
||||||
# them during input/label pre-processing. This avoids surprising the user by wrecking their data.
|
# them during input/label pre-processing. This avoids surprising the user by wrecking their data.
|
||||||
# In addition, modifying mutable Python inputs makes XLA compilation impossible.
|
# In addition, modifying mutable Python inputs makes XLA compilation impossible.
|
||||||
@@ -1635,8 +1630,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
output_to_label = {val: key for key, val in label_to_output.items()}
|
output_to_label = {val: key for key, val in label_to_output.items()}
|
||||||
if not self._using_dummy_loss and parse(tf.__version__) < parse("2.11.0"):
|
if not self._using_dummy_loss and parse(tf.__version__) < parse("2.11.0"):
|
||||||
# Newer versions leave this out
|
# Newer versions leave this out
|
||||||
data = data_adapter.expand_1d(data)
|
data = expand_1d(data)
|
||||||
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
|
x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
|
||||||
# If the inputs are mutable dictionaries, make a shallow copy of them because we will modify
|
# If the inputs are mutable dictionaries, make a shallow copy of them because we will modify
|
||||||
# them during input/label pre-processing. This avoids surprising the user by wrecking their data.
|
# them during input/label pre-processing. This avoids surprising the user by wrecking their data.
|
||||||
# In addition, modifying mutable Python inputs makes XLA compilation impossible.
|
# In addition, modifying mutable Python inputs makes XLA compilation impossible.
|
||||||
@@ -2402,7 +2397,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
)
|
)
|
||||||
param_dset[:] = layer.numpy()
|
param_dset[:] = layer.numpy()
|
||||||
layers.append(layer_name.encode("utf8"))
|
layers.append(layer_name.encode("utf8"))
|
||||||
hdf5_format.save_attributes_to_hdf5_group(shard_file, "layer_names", layers)
|
save_attributes_to_hdf5_group(shard_file, "layer_names", layers)
|
||||||
|
|
||||||
if push_to_hub:
|
if push_to_hub:
|
||||||
self._upload_modified_files(
|
self._upload_modified_files(
|
||||||
|
|||||||
@@ -15,7 +15,6 @@
|
|||||||
import collections
|
import collections
|
||||||
import csv
|
import csv
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
@@ -36,7 +35,7 @@ from ..image_processing_utils import BaseImageProcessor
|
|||||||
from ..modelcard import ModelCard
|
from ..modelcard import ModelCard
|
||||||
from ..models.auto.configuration_auto import AutoConfig
|
from ..models.auto.configuration_auto import AutoConfig
|
||||||
from ..tokenization_utils import PreTrainedTokenizer
|
from ..tokenization_utils import PreTrainedTokenizer
|
||||||
from ..utils import ModelOutput, add_end_docstrings, is_tf_available, is_torch_available, logging
|
from ..utils import ModelOutput, add_end_docstrings, infer_framework, is_tf_available, is_torch_available, logging
|
||||||
|
|
||||||
|
|
||||||
GenericTensor = Union[List["GenericTensor"], "torch.Tensor", "tf.Tensor"]
|
GenericTensor = Union[List["GenericTensor"], "torch.Tensor", "tf.Tensor"]
|
||||||
@@ -278,7 +277,7 @@ def infer_framework_load_model(
|
|||||||
if isinstance(model, str):
|
if isinstance(model, str):
|
||||||
raise ValueError(f"Could not load model {model} with any of the following classes: {class_tuple}.")
|
raise ValueError(f"Could not load model {model} with any of the following classes: {class_tuple}.")
|
||||||
|
|
||||||
framework = "tf" if "keras.engine.training.Model" in str(inspect.getmro(model.__class__)) else "pt"
|
framework = infer_framework(model.__class__)
|
||||||
return framework, model
|
return framework, model
|
||||||
|
|
||||||
|
|
||||||
@@ -351,7 +350,7 @@ def get_framework(model, revision: Optional[str] = None):
|
|||||||
except OSError:
|
except OSError:
|
||||||
model = TFAutoModel.from_pretrained(model, revision=revision)
|
model = TFAutoModel.from_pretrained(model, revision=revision)
|
||||||
|
|
||||||
framework = "tf" if "keras.engine.training.Model" in str(inspect.getmro(model.__class__)) else "pt"
|
framework = infer_framework(model.__class__)
|
||||||
return framework
|
return framework
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -166,3 +166,90 @@ def check_embeddings_within_bounds(tensor: tf.Tensor, embed_dim: int, tensor_nam
|
|||||||
f"layer's input dimension ({embed_dim}). The likely cause is some problem at tokenization time."
|
f"layer's input dimension ({embed_dim}). The likely cause is some problem at tokenization time."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def save_attributes_to_hdf5_group(group, name, data):
|
||||||
|
"""Saves attributes (data) of the specified name into the HDF5 group.
|
||||||
|
|
||||||
|
This method deals with an inherent problem of HDF5 file which is not able to store data larger than
|
||||||
|
HDF5_OBJECT_HEADER_LIMIT bytes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group: A pointer to a HDF5 group.
|
||||||
|
name: A name of the attributes to save.
|
||||||
|
data: Attributes data to store.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If any single attribute is too large to be saved.
|
||||||
|
|
||||||
|
Copied from Keras to Transformers to avoid versioning issues.
|
||||||
|
"""
|
||||||
|
HDF5_OBJECT_HEADER_LIMIT = 64512
|
||||||
|
# Check that no item in `data` is larger than `HDF5_OBJECT_HEADER_LIMIT`
|
||||||
|
# because in that case even chunking the array would not make the saving
|
||||||
|
# possible.
|
||||||
|
bad_attributes = [x for x in data if len(x) > HDF5_OBJECT_HEADER_LIMIT]
|
||||||
|
|
||||||
|
# Expecting this to never be true.
|
||||||
|
if bad_attributes:
|
||||||
|
raise RuntimeError(
|
||||||
|
"The following attributes cannot be saved to HDF5 file because "
|
||||||
|
f"they are larger than {HDF5_OBJECT_HEADER_LIMIT} "
|
||||||
|
f"bytes: {bad_attributes}"
|
||||||
|
)
|
||||||
|
|
||||||
|
data_npy = np.asarray(data)
|
||||||
|
|
||||||
|
num_chunks = 1
|
||||||
|
chunked_data = np.array_split(data_npy, num_chunks)
|
||||||
|
|
||||||
|
# This will never loop forever thanks to the test above.
|
||||||
|
while any(x.nbytes > HDF5_OBJECT_HEADER_LIMIT for x in chunked_data):
|
||||||
|
num_chunks += 1
|
||||||
|
chunked_data = np.array_split(data_npy, num_chunks)
|
||||||
|
|
||||||
|
if num_chunks > 1:
|
||||||
|
for chunk_id, chunk_data in enumerate(chunked_data):
|
||||||
|
group.attrs["%s%d" % (name, chunk_id)] = chunk_data
|
||||||
|
else:
|
||||||
|
group.attrs[name] = data
|
||||||
|
|
||||||
|
|
||||||
|
def load_attributes_from_hdf5_group(group, name):
|
||||||
|
"""Loads attributes of the specified name from the HDF5 group.
|
||||||
|
|
||||||
|
This method deals with an inherent problem of HDF5 file which is not able to store data larger than
|
||||||
|
HDF5_OBJECT_HEADER_LIMIT bytes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group: A pointer to a HDF5 group.
|
||||||
|
name: A name of the attributes to load.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
data: Attributes data.
|
||||||
|
|
||||||
|
Copied from Keras to Transformers to avoid versioning issues.
|
||||||
|
"""
|
||||||
|
if name in group.attrs:
|
||||||
|
data = [n.decode("utf8") if hasattr(n, "decode") else n for n in group.attrs[name]]
|
||||||
|
else:
|
||||||
|
data = []
|
||||||
|
chunk_id = 0
|
||||||
|
while "%s%d" % (name, chunk_id) in group.attrs:
|
||||||
|
data.extend(
|
||||||
|
[n.decode("utf8") if hasattr(n, "decode") else n for n in group.attrs["%s%d" % (name, chunk_id)]]
|
||||||
|
)
|
||||||
|
chunk_id += 1
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def expand_1d(data):
|
||||||
|
"""Expands 1-dimensional `Tensor`s into 2-dimensional `Tensor`s.
|
||||||
|
Copied from Keras to here to avoid versioning issues."""
|
||||||
|
|
||||||
|
def _expand_single_1d_tensor(t):
|
||||||
|
if isinstance(t, tf.Tensor) and t.shape.rank == 1:
|
||||||
|
return tf.expand_dims(t, axis=-1)
|
||||||
|
return t
|
||||||
|
|
||||||
|
return tf.nest.map_structure(_expand_single_1d_tensor, data)
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ from .generic import (
|
|||||||
expand_dims,
|
expand_dims,
|
||||||
find_labels,
|
find_labels,
|
||||||
flatten_dict,
|
flatten_dict,
|
||||||
|
infer_framework,
|
||||||
is_jax_tensor,
|
is_jax_tensor,
|
||||||
is_numpy_array,
|
is_numpy_array,
|
||||||
is_tensor,
|
is_tensor,
|
||||||
|
|||||||
@@ -398,11 +398,10 @@ def can_return_loss(model_class):
|
|||||||
Args:
|
Args:
|
||||||
model_class (`type`): The class of the model.
|
model_class (`type`): The class of the model.
|
||||||
"""
|
"""
|
||||||
base_classes = str(inspect.getmro(model_class))
|
framework = infer_framework(model_class)
|
||||||
|
if framework == "tf":
|
||||||
if "keras.engine.training.Model" in base_classes:
|
|
||||||
signature = inspect.signature(model_class.call) # TensorFlow models
|
signature = inspect.signature(model_class.call) # TensorFlow models
|
||||||
elif "torch.nn.modules.module.Module" in base_classes:
|
elif framework == "pt":
|
||||||
signature = inspect.signature(model_class.forward) # PyTorch models
|
signature = inspect.signature(model_class.forward) # PyTorch models
|
||||||
else:
|
else:
|
||||||
signature = inspect.signature(model_class.__call__) # Flax models
|
signature = inspect.signature(model_class.__call__) # Flax models
|
||||||
@@ -422,11 +421,10 @@ def find_labels(model_class):
|
|||||||
model_class (`type`): The class of the model.
|
model_class (`type`): The class of the model.
|
||||||
"""
|
"""
|
||||||
model_name = model_class.__name__
|
model_name = model_class.__name__
|
||||||
base_classes = str(inspect.getmro(model_class))
|
framework = infer_framework(model_class)
|
||||||
|
if framework == "tf":
|
||||||
if "keras.engine.training.Model" in base_classes:
|
|
||||||
signature = inspect.signature(model_class.call) # TensorFlow models
|
signature = inspect.signature(model_class.call) # TensorFlow models
|
||||||
elif "torch.nn.modules.module.Module" in base_classes:
|
elif framework == "pt":
|
||||||
signature = inspect.signature(model_class.forward) # PyTorch models
|
signature = inspect.signature(model_class.forward) # PyTorch models
|
||||||
else:
|
else:
|
||||||
signature = inspect.signature(model_class.__call__) # Flax models
|
signature = inspect.signature(model_class.__call__) # Flax models
|
||||||
@@ -565,3 +563,21 @@ def add_model_info_to_auto_map(auto_map, repo_id):
|
|||||||
auto_map[key] = f"{repo_id}--{value}"
|
auto_map[key] = f"{repo_id}--{value}"
|
||||||
|
|
||||||
return auto_map
|
return auto_map
|
||||||
|
|
||||||
|
|
||||||
|
def infer_framework(model_class):
|
||||||
|
"""
|
||||||
|
Infers the framework of a given model without using isinstance(), because we cannot guarantee that the relevant
|
||||||
|
classes are imported or available.
|
||||||
|
"""
|
||||||
|
for base_class in inspect.getmro(model_class):
|
||||||
|
module = base_class.__module__
|
||||||
|
name = base_class.__name__
|
||||||
|
if module.startswith("tensorflow") or module.startswith("keras") or name == "TFPreTrainedModel":
|
||||||
|
return "tf"
|
||||||
|
elif module.startswith("torch") or name == "PreTrainedModel":
|
||||||
|
return "pt"
|
||||||
|
elif module.startswith("flax") or module.startswith("jax") or name == "FlaxPreTrainedModel":
|
||||||
|
return "flax"
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Could not infer framework from class {model_class}.")
|
||||||
|
|||||||
Reference in New Issue
Block a user