Safetensors tf (#19900)

* Wip

* Add safetensors support for TensorFlow

* First tests

* Add final test for now

* Retrigger CI like this

* Update src/transformers/modeling_tf_utils.py

Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>

Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
Sylvain Gugger
2022-10-27 15:56:29 -04:00
committed by GitHub
parent e4132952a1
commit 6c24443ff5
5 changed files with 208 additions and 15 deletions

View File

@@ -43,13 +43,14 @@ from transformers.testing_utils import ( # noqa: F401
_tf_gpu_memory_limit,
is_pt_tf_cross_test,
is_staging_test,
require_safetensors,
require_tf,
require_tf2onnx,
slow,
tooslow,
torch_device,
)
from transformers.utils import logging
from transformers.utils import SAFE_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging
from transformers.utils.generic import ModelOutput
@@ -94,12 +95,7 @@ if is_tf_available():
TFSampleDecoderOnlyOutput,
TFSampleEncoderDecoderOutput,
)
from transformers.modeling_tf_utils import (
TF2_WEIGHTS_INDEX_NAME,
TF2_WEIGHTS_NAME,
tf_shard_checkpoint,
unpack_inputs,
)
from transformers.modeling_tf_utils import tf_shard_checkpoint, unpack_inputs
from transformers.tf_utils import stable_softmax
if _tf_gpu_memory_limit is not None:
@@ -119,6 +115,8 @@ if is_tf_available():
if is_torch_available():
import torch
from transformers import BertModel
def _config_zero_init(config):
configs_no_init = copy.deepcopy(config)
@@ -2168,7 +2166,7 @@ class UtilsFunctionsTest(unittest.TestCase):
)
@slow
def test_special_layer_name_shardind(self):
def test_special_layer_name_sharding(self):
retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
model = TFRagModel.from_pretrained("facebook/rag-token-nq", retriever=retriever)
@@ -2268,6 +2266,54 @@ class UtilsFunctionsTest(unittest.TestCase):
self.assertTrue("custom_signature_1" in list(model_loaded.signatures.keys()))
self.assertTrue("custom_signature_2" in list(model_loaded.signatures.keys()))
@require_safetensors
def test_safetensors_save_and_load(self):
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, safe_serialization=True)
# No tf_model.h5 file, only a model.safetensors
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
new_model = TFBertModel.from_pretrained(tmp_dir)
# Check models are equal
for p1, p2 in zip(model.weights, new_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
@is_pt_tf_cross_test
def test_safetensors_save_and_load_pt_to_tf(self):
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with tempfile.TemporaryDirectory() as tmp_dir:
pt_model.save_pretrained(tmp_dir, safe_serialization=True)
# Check we have a model.safetensors file
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
new_model = TFBertModel.from_pretrained(tmp_dir)
# Check models are equal
for p1, p2 in zip(model.weights, new_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
@require_safetensors
def test_safetensors_load_from_hub(self):
tf_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# Can load from the TF-formatted checkpoint
safetensors_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors-tf")
# Check models are equal
for p1, p2 in zip(safetensors_model.weights, tf_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
# Can load from the PyTorch-formatted checkpoint
safetensors_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors")
# Check models are equal
for p1, p2 in zip(safetensors_model.weights, tf_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
@require_tf
@is_staging_test