Safetensors serialization by default (#27064)

* Safetensors serialization by default

* First pass on the tests

* Second pass on the tests

* Third pass on the tests

* Fix TF weight loading from TF-format safetensors

* Specific encoder-decoder fixes for weight crossloading

* Add VisionEncoderDecoder fixes for TF too

* Change filename test for pt-to-tf

* One missing fix for TFVisionEncoderDecoder

* Fix the other crossload test

* Support for flax + updated tests

* Apply suggestions from code review

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Sanchit's comments

* Sanchit's comments 2

* Nico's comments

* Fix tests

* cleanup

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: Matt <rocketknight1@gmail.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Lysandre Debut
2023-10-31 19:16:49 +01:00
committed by GitHub
parent 25e6e9418c
commit 113ebf80ac
20 changed files with 433 additions and 137 deletions

View File

@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import glob
import json
import os
@@ -42,7 +42,9 @@ from transformers.testing_utils import (
TestCasePlus,
is_staging_test,
require_accelerate,
require_flax,
require_safetensors,
require_tf,
require_torch,
require_torch_accelerator,
require_torch_multi_accelerator,
@@ -56,7 +58,7 @@ from transformers.utils import (
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
)
from transformers.utils.import_utils import is_torchdynamo_available
from transformers.utils.import_utils import is_flax_available, is_tf_available, is_torchdynamo_available
sys.path.append(str(Path(__file__).parent.parent / "utils"))
@@ -66,6 +68,7 @@ from test_module.custom_configuration import CustomConfig, NoSuperInitConfig #
if is_torch_available():
import torch
from safetensors.torch import save_file as safe_save_file
from test_module.custom_modeling import CustomModel, NoSuperInitModel
from torch import nn
@@ -146,6 +149,13 @@ if is_torch_available():
self.decoder.weight = self.base.linear.weight
if is_flax_available():
from transformers import FlaxBertModel
if is_tf_available():
from transformers import TFBertModel
TINY_T5 = "patrickvonplaten/t5-tiny-random"
TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"
@@ -420,13 +430,13 @@ class ModelUtilsTest(TestCasePlus):
},
)
def test_checkpoint_sharding_local(self):
def test_checkpoint_sharding_local_bin(self):
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with tempfile.TemporaryDirectory() as tmp_dir:
# We use the same folder for various sizes to make sure a new save erases the old checkpoint.
for max_size in ["50kB", "50kiB", "100kB", "100kiB", "200kB", "200kiB"]:
model.save_pretrained(tmp_dir, max_shard_size=max_size)
model.save_pretrained(tmp_dir, max_shard_size=max_size, safe_serialization=False)
# Get each shard file and its size
shard_to_size = {}
@@ -472,11 +482,11 @@ class ModelUtilsTest(TestCasePlus):
for p1, p2 in zip(model.parameters(), ref_model.parameters()):
self.assertTrue(torch.allclose(p1, p2))
def test_checkpoint_variant_local(self):
def test_checkpoint_variant_local_bin(self):
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, variant="v2")
model.save_pretrained(tmp_dir, variant="v2", safe_serialization=False)
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"])
@@ -492,11 +502,11 @@ class ModelUtilsTest(TestCasePlus):
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.allclose(p1, p2))
def test_checkpoint_variant_local_sharded(self):
def test_checkpoint_variant_local_sharded_bin(self):
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, variant="v2", max_shard_size="50kB")
model.save_pretrained(tmp_dir, variant="v2", max_shard_size="50kB", safe_serialization=False)
weights_index_name = ".".join(WEIGHTS_INDEX_NAME.split(".")[:-1] + ["v2"] + ["json"])
weights_index_file = os.path.join(tmp_dir, weights_index_name)
@@ -604,18 +614,18 @@ class ModelUtilsTest(TestCasePlus):
)
self.assertIsNotNone(model)
def test_checkpoint_variant_save_load(self):
def test_checkpoint_variant_save_load_bin(self):
with tempfile.TemporaryDirectory() as tmp_dir:
model = BertModel.from_pretrained(
"hf-internal-testing/tiny-random-bert-variant", cache_dir=tmp_dir, variant="v2"
)
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"])
model.save_pretrained(tmp_dir, variant="v2")
model.save_pretrained(tmp_dir, variant="v2", safe_serialization=False)
# saving will create a variant checkpoint
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, weights_name)))
model.save_pretrained(tmp_dir)
model.save_pretrained(tmp_dir, safe_serialization=False)
# saving shouldn't delete variant checkpoints
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"])
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, weights_name)))
@@ -874,7 +884,7 @@ class ModelUtilsTest(TestCasePlus):
def test_base_model_to_head_model_load(self):
base_model = BaseModel(PretrainedConfig())
with tempfile.TemporaryDirectory() as tmp_dir:
base_model.save_pretrained(tmp_dir)
base_model.save_pretrained(tmp_dir, safe_serialization=False)
# Can load a base model in a model with head
model = ModelWithHead.from_pretrained(tmp_dir)
@@ -886,7 +896,7 @@ class ModelUtilsTest(TestCasePlus):
head_state_dict = model.state_dict()
base_state_dict["linear2.weight"] = head_state_dict["linear2.weight"]
base_state_dict["linear2.bias"] = head_state_dict["linear2.bias"]
torch.save(base_state_dict, os.path.join(tmp_dir, WEIGHTS_NAME))
safe_save_file(base_state_dict, os.path.join(tmp_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
with self.assertRaisesRegex(
ValueError, "The state dictionary of the model you are trying to load is corrupted."
@@ -934,8 +944,8 @@ class ModelUtilsTest(TestCasePlus):
# Loading the model with the same class, we do get a warning for unexpected weights
state_dict = model.state_dict()
state_dict["added_key"] = state_dict["linear.weight"]
torch.save(state_dict, os.path.join(tmp_dir, WEIGHTS_NAME))
state_dict["added_key"] = copy.deepcopy(state_dict["linear.weight"])
safe_save_file(state_dict, os.path.join(tmp_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
with CaptureLogger(logger) as cl:
_, loading_info = ModelWithHead.from_pretrained(tmp_dir, output_loading_info=True)
self.assertIn("were not used when initializing ModelWithHead: ['added_key']", cl.out)
@@ -1072,6 +1082,54 @@ class ModelUtilsTest(TestCasePlus):
)
self.assertEqual(model.generation_config.transformers_version, "foo")
@require_safetensors
def test_safetensors_torch_from_torch(self):
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, safe_serialization=True)
new_model = BertModel.from_pretrained(tmp_dir)
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
@require_safetensors
@require_flax
def test_safetensors_torch_from_flax(self):
hub_model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, safe_serialization=True)
new_model = BertModel.from_pretrained(tmp_dir)
for p1, p2 in zip(hub_model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
@require_tf
@require_safetensors
def test_safetensors_torch_from_tf(self):
hub_model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-only")
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, safe_serialization=True)
new_model = BertModel.from_pretrained(tmp_dir)
for p1, p2 in zip(hub_model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
@require_safetensors
def test_safetensors_torch_from_torch_sharded(self):
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="100kB")
new_model = BertModel.from_pretrained(tmp_dir)
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
@require_torch
@is_staging_test