Safetensors tf (#19900)
* Wip * Add safetensors support for TensorFlow * First tests * Add final test for now * Retrigger CI like this * Update src/transformers/modeling_tf_utils.py Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr> Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
@@ -43,13 +43,14 @@ from transformers.testing_utils import ( # noqa: F401
|
||||
_tf_gpu_memory_limit,
|
||||
is_pt_tf_cross_test,
|
||||
is_staging_test,
|
||||
require_safetensors,
|
||||
require_tf,
|
||||
require_tf2onnx,
|
||||
slow,
|
||||
tooslow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
from transformers.utils import SAFE_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging
|
||||
from transformers.utils.generic import ModelOutput
|
||||
|
||||
|
||||
@@ -94,12 +95,7 @@ if is_tf_available():
|
||||
TFSampleDecoderOnlyOutput,
|
||||
TFSampleEncoderDecoderOutput,
|
||||
)
|
||||
from transformers.modeling_tf_utils import (
|
||||
TF2_WEIGHTS_INDEX_NAME,
|
||||
TF2_WEIGHTS_NAME,
|
||||
tf_shard_checkpoint,
|
||||
unpack_inputs,
|
||||
)
|
||||
from transformers.modeling_tf_utils import tf_shard_checkpoint, unpack_inputs
|
||||
from transformers.tf_utils import stable_softmax
|
||||
|
||||
if _tf_gpu_memory_limit is not None:
|
||||
@@ -119,6 +115,8 @@ if is_tf_available():
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import BertModel
|
||||
|
||||
|
||||
def _config_zero_init(config):
|
||||
configs_no_init = copy.deepcopy(config)
|
||||
@@ -2168,7 +2166,7 @@ class UtilsFunctionsTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_special_layer_name_shardind(self):
|
||||
def test_special_layer_name_sharding(self):
|
||||
retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
|
||||
model = TFRagModel.from_pretrained("facebook/rag-token-nq", retriever=retriever)
|
||||
|
||||
@@ -2268,6 +2266,54 @@ class UtilsFunctionsTest(unittest.TestCase):
|
||||
self.assertTrue("custom_signature_1" in list(model_loaded.signatures.keys()))
|
||||
self.assertTrue("custom_signature_2" in list(model_loaded.signatures.keys()))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_save_and_load(self):
|
||||
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, safe_serialization=True)
|
||||
# No tf_model.h5 file, only a model.safetensors
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
|
||||
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
|
||||
|
||||
new_model = TFBertModel.from_pretrained(tmp_dir)
|
||||
|
||||
# Check models are equal
|
||||
for p1, p2 in zip(model.weights, new_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_safetensors_save_and_load_pt_to_tf(self):
|
||||
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pt_model.save_pretrained(tmp_dir, safe_serialization=True)
|
||||
# Check we have a model.safetensors file
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
|
||||
|
||||
new_model = TFBertModel.from_pretrained(tmp_dir)
|
||||
|
||||
# Check models are equal
|
||||
for p1, p2 in zip(model.weights, new_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_hub(self):
|
||||
tf_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
|
||||
# Can load from the TF-formatted checkpoint
|
||||
safetensors_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors-tf")
|
||||
|
||||
# Check models are equal
|
||||
for p1, p2 in zip(safetensors_model.weights, tf_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
# Can load from the PyTorch-formatted checkpoint
|
||||
safetensors_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors")
|
||||
|
||||
# Check models are equal
|
||||
for p1, p2 in zip(safetensors_model.weights, tf_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
|
||||
@require_tf
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user