Safetensors serialization by default (#27064)

* Safetensors serialization by default

* First pass on the tests

* Second pass on the tests

* Third pass on the tests

* Fix TF weight loading from TF-format safetensors

* Specific encoder-decoder fixes for weight crossloading

* Add VisionEncoderDecoder fixes for TF too

* Change filename test for pt-to-tf

* One missing fix for TFVisionEncoderDecoder

* Fix the other crossload test

* Support for flax + updated tests

* Apply suggestions from code review

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Sanchit's comments

* Sanchit's comments 2

* Nico's comments

* Fix tests

* cleanup

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: Matt <rocketknight1@gmail.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Lysandre Debut
2023-10-31 19:16:49 +01:00
committed by GitHub
parent 25e6e9418c
commit 113ebf80ac
20 changed files with 433 additions and 137 deletions

@@ -91,6 +91,7 @@ if is_accelerate_available():
 if is_torch_available():
     import torch
+    from safetensors.torch import save_file as safe_save_file
     from torch import nn
     from transformers import MODEL_MAPPING, AdaptiveEmbedding
@@ -1751,8 +1752,8 @@ class ModelTesterMixin:
 # We are nuking ALL weights on file, so every parameter should
 # yell on load. We're going to detect if we yell too much, or too little.
-with open(os.path.join(tmp_dir, "pytorch_model.bin"), "wb") as f:
-    torch.save({}, f)
+placeholder_dict = {"tensor": torch.tensor([1, 2])}
+safe_save_file(placeholder_dict, os.path.join(tmp_dir, "model.safetensors"), metadata={"format": "pt"})
 model_reloaded, infos = model_class.from_pretrained(tmp_dir, output_loading_info=True)
 prefix = f"{model_reloaded.base_model_prefix}."
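
The idea behind the updated test can be sketched without any framework. When the checkpoint on disk holds only a placeholder tensor, every model parameter should be reported as missing and the placeholder itself as unexpected. The helper `compare_keys` and the parameter names below are hypothetical, chosen only for illustration. This is a minimal sketch of the check, not how `from_pretrained` builds its loading info.

```python
def compare_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, sorted for stable output.

    Hypothetical helper: `missing` are names the model expects but the
    checkpoint lacks; `unexpected` are names only the checkpoint has.
    """
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Hypothetical parameter names for a tiny model (an assumption, not from the diff).
model_keys = ["encoder.embed.weight", "encoder.layer.0.weight", "head.bias"]

# The placeholder checkpoint holds a single unrelated entry named "tensor",
# mirroring placeholder_dict in the test above.
checkpoint_keys = ["tensor"]

missing, unexpected = compare_keys(model_keys, checkpoint_keys)

# Every parameter should "yell": none of them exists in the checkpoint.
assert missing == sorted(model_keys)
# The placeholder matches nothing in the model.
assert unexpected == ["tensor"]
```

A non-empty placeholder is used here for the same reason as in the diff: the file must hold at least one entry that matches no model parameter, so that the load reports every real weight as missing.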