Safetensors serialization by default (#27064)
* Safetensors serialization by default
* First pass on the tests
* Second pass on the tests
* Third pass on the tests
* Fix TF weight loading from TF-format safetensors
* Specific encoder-decoder fixes for weight crossloading
* Add VisionEncoderDecoder fixes for TF too
* Change filename test for pt-to-tf
* One missing fix for TFVisionEncoderDecoder
* Fix the other crossload test
* Support for flax + updated tests
* Apply suggestions from code review
  Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* Sanchit's comments
* Sanchit's comments 2
* Nico's comments
* Fix tests
* cleanup
* Apply suggestions from code review
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
---------
Co-authored-by: Matt <rocketknight1@gmail.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -12,7 +12,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
@@ -42,7 +42,9 @@ from transformers.testing_utils import (
|
||||
TestCasePlus,
|
||||
is_staging_test,
|
||||
require_accelerate,
|
||||
require_flax,
|
||||
require_safetensors,
|
||||
require_tf,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
require_torch_multi_accelerator,
|
||||
@@ -56,7 +58,7 @@ from transformers.utils import (
|
||||
WEIGHTS_INDEX_NAME,
|
||||
WEIGHTS_NAME,
|
||||
)
|
||||
from transformers.utils.import_utils import is_torchdynamo_available
|
||||
from transformers.utils.import_utils import is_flax_available, is_tf_available, is_torchdynamo_available
|
||||
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent / "utils"))
|
||||
@@ -66,6 +68,7 @@ from test_module.custom_configuration import CustomConfig, NoSuperInitConfig #
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from safetensors.torch import save_file as safe_save_file
|
||||
from test_module.custom_modeling import CustomModel, NoSuperInitModel
|
||||
from torch import nn
|
||||
|
||||
@@ -146,6 +149,13 @@ if is_torch_available():
|
||||
self.decoder.weight = self.base.linear.weight
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
from transformers import FlaxBertModel
|
||||
|
||||
if is_tf_available():
|
||||
from transformers import TFBertModel
|
||||
|
||||
|
||||
TINY_T5 = "patrickvonplaten/t5-tiny-random"
|
||||
TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"
|
||||
|
||||
@@ -420,13 +430,13 @@ class ModelUtilsTest(TestCasePlus):
|
||||
},
|
||||
)
|
||||
|
||||
def test_checkpoint_sharding_local(self):
|
||||
def test_checkpoint_sharding_local_bin(self):
|
||||
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# We use the same folder for various sizes to make sure a new save erases the old checkpoint.
|
||||
for max_size in ["50kB", "50kiB", "100kB", "100kiB", "200kB", "200kiB"]:
|
||||
model.save_pretrained(tmp_dir, max_shard_size=max_size)
|
||||
model.save_pretrained(tmp_dir, max_shard_size=max_size, safe_serialization=False)
|
||||
|
||||
# Get each shard file and its size
|
||||
shard_to_size = {}
|
||||
@@ -472,11 +482,11 @@ class ModelUtilsTest(TestCasePlus):
|
||||
for p1, p2 in zip(model.parameters(), ref_model.parameters()):
|
||||
self.assertTrue(torch.allclose(p1, p2))
|
||||
|
||||
def test_checkpoint_variant_local(self):
|
||||
def test_checkpoint_variant_local_bin(self):
|
||||
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, variant="v2")
|
||||
model.save_pretrained(tmp_dir, variant="v2", safe_serialization=False)
|
||||
|
||||
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"])
|
||||
|
||||
@@ -492,11 +502,11 @@ class ModelUtilsTest(TestCasePlus):
|
||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||
self.assertTrue(torch.allclose(p1, p2))
|
||||
|
||||
def test_checkpoint_variant_local_sharded(self):
|
||||
def test_checkpoint_variant_local_sharded_bin(self):
|
||||
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, variant="v2", max_shard_size="50kB")
|
||||
model.save_pretrained(tmp_dir, variant="v2", max_shard_size="50kB", safe_serialization=False)
|
||||
|
||||
weights_index_name = ".".join(WEIGHTS_INDEX_NAME.split(".")[:-1] + ["v2"] + ["json"])
|
||||
weights_index_file = os.path.join(tmp_dir, weights_index_name)
|
||||
@@ -604,18 +614,18 @@ class ModelUtilsTest(TestCasePlus):
|
||||
)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
def test_checkpoint_variant_save_load(self):
|
||||
def test_checkpoint_variant_save_load_bin(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model = BertModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-bert-variant", cache_dir=tmp_dir, variant="v2"
|
||||
)
|
||||
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"])
|
||||
|
||||
model.save_pretrained(tmp_dir, variant="v2")
|
||||
model.save_pretrained(tmp_dir, variant="v2", safe_serialization=False)
|
||||
# saving will create a variant checkpoint
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, weights_name)))
|
||||
|
||||
model.save_pretrained(tmp_dir)
|
||||
model.save_pretrained(tmp_dir, safe_serialization=False)
|
||||
# saving shouldn't delete variant checkpoints
|
||||
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"])
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, weights_name)))
|
||||
@@ -874,7 +884,7 @@ class ModelUtilsTest(TestCasePlus):
|
||||
def test_base_model_to_head_model_load(self):
|
||||
base_model = BaseModel(PretrainedConfig())
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
base_model.save_pretrained(tmp_dir)
|
||||
base_model.save_pretrained(tmp_dir, safe_serialization=False)
|
||||
|
||||
# Can load a base model in a model with head
|
||||
model = ModelWithHead.from_pretrained(tmp_dir)
|
||||
@@ -886,7 +896,7 @@ class ModelUtilsTest(TestCasePlus):
|
||||
head_state_dict = model.state_dict()
|
||||
base_state_dict["linear2.weight"] = head_state_dict["linear2.weight"]
|
||||
base_state_dict["linear2.bias"] = head_state_dict["linear2.bias"]
|
||||
torch.save(base_state_dict, os.path.join(tmp_dir, WEIGHTS_NAME))
|
||||
safe_save_file(base_state_dict, os.path.join(tmp_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "The state dictionary of the model you are trying to load is corrupted."
|
||||
@@ -934,8 +944,8 @@ class ModelUtilsTest(TestCasePlus):
|
||||
|
||||
# Loading the model with the same class, we do get a warning for unexpected weights
|
||||
state_dict = model.state_dict()
|
||||
state_dict["added_key"] = state_dict["linear.weight"]
|
||||
torch.save(state_dict, os.path.join(tmp_dir, WEIGHTS_NAME))
|
||||
state_dict["added_key"] = copy.deepcopy(state_dict["linear.weight"])
|
||||
safe_save_file(state_dict, os.path.join(tmp_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
|
||||
with CaptureLogger(logger) as cl:
|
||||
_, loading_info = ModelWithHead.from_pretrained(tmp_dir, output_loading_info=True)
|
||||
self.assertIn("were not used when initializing ModelWithHead: ['added_key']", cl.out)
|
||||
@@ -1072,6 +1082,54 @@ class ModelUtilsTest(TestCasePlus):
|
||||
)
|
||||
self.assertEqual(model.generation_config.transformers_version, "foo")
|
||||
|
||||
@require_safetensors
def test_safetensors_torch_from_torch(self):
    """Save a PyTorch `BertModel` with `safe_serialization=True` and reload it.

    Every parameter of the reloaded model must be exactly equal
    (`torch.equal`) to the matching parameter of the model that was saved.
    """
    reference = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")

    with tempfile.TemporaryDirectory() as save_dir:
        reference.save_pretrained(save_dir, safe_serialization=True)
        reloaded = BertModel.from_pretrained(save_dir)

    # Parameters are paired positionally; both models are the same class.
    for expected, actual in zip(reference.parameters(), reloaded.parameters()):
        self.assertTrue(torch.equal(expected, actual))
|
||||
|
||||
@require_safetensors
@require_flax
def test_safetensors_torch_from_flax(self):
    """Load into PyTorch a safetensors checkpoint written by a Flax model.

    A `FlaxBertModel` is saved with `safe_serialization=True`, the directory
    is loaded with the PyTorch `BertModel`, and the result is compared
    parameter by parameter against the PyTorch checkpoint from the Hub.
    """
    # NOTE(review): presumably the "-pt-only" and "-flax-only" repos hold the
    # same weights in different formats -- confirm against the Hub repos.
    torch_reference = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
    flax_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")

    with tempfile.TemporaryDirectory() as save_dir:
        flax_model.save_pretrained(save_dir, safe_serialization=True)
        reloaded = BertModel.from_pretrained(save_dir)

    for expected, actual in zip(torch_reference.parameters(), reloaded.parameters()):
        self.assertTrue(torch.equal(expected, actual))
|
||||
|
||||
@require_tf
@require_safetensors
def test_safetensors_torch_from_tf(self):
    """Load into PyTorch a safetensors checkpoint written by a TensorFlow model.

    A `TFBertModel` is saved with `safe_serialization=True`, the directory
    is loaded with the PyTorch `BertModel`, and the result is compared
    parameter by parameter against the PyTorch checkpoint from the Hub.
    """
    # NOTE(review): presumably the "-pt-only" and "-tf-only" repos hold the
    # same weights in different formats -- confirm against the Hub repos.
    torch_reference = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
    tf_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-only")

    with tempfile.TemporaryDirectory() as save_dir:
        tf_model.save_pretrained(save_dir, safe_serialization=True)
        reloaded = BertModel.from_pretrained(save_dir)

    for expected, actual in zip(torch_reference.parameters(), reloaded.parameters()):
        self.assertTrue(torch.equal(expected, actual))
|
||||
|
||||
@require_safetensors
def test_safetensors_torch_from_torch_sharded(self):
    """Save a PyTorch `BertModel` as safetensors with a shard size cap and reload it.

    The model is saved with `safe_serialization=True` and
    `max_shard_size="100kB"`; every reloaded parameter must be exactly equal
    to the original one.
    """
    reference = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")

    with tempfile.TemporaryDirectory() as save_dir:
        # NOTE(review): the test does not itself check that more than one
        # shard file was written -- it only checks the reloaded weights.
        reference.save_pretrained(save_dir, safe_serialization=True, max_shard_size="100kB")
        reloaded = BertModel.from_pretrained(save_dir)

    for expected, actual in zip(reference.parameters(), reloaded.parameters()):
        self.assertTrue(torch.equal(expected, actual))
|
||||
|
||||
|
||||
@require_torch
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user