Safetensors serialization by default (#27064)
* Safetensors serialization by default
* First pass on the tests
* Second pass on the tests
* Third pass on the tests
* Fix TF weight loading from TF-format safetensors
* Specific encoder-decoder fixes for weight crossloading
* Add VisionEncoderDecoder fixes for TF too
* Change filename test for pt-to-tf
* One missing fix for TFVisionEncoderDecoder
* Fix the other crossload test
* Support for flax + updated tests
* Apply suggestions from code review

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Sanchit's comments
* Sanchit's comments 2
* Nico's comments
* Fix tests
* cleanup
* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: Matt <rocketknight1@gmail.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -91,6 +91,7 @@ if is_accelerate_available():
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from safetensors.torch import save_file as safe_save_file
|
||||
from torch import nn
|
||||
|
||||
from transformers import MODEL_MAPPING, AdaptiveEmbedding
|
||||
@@ -1751,8 +1752,8 @@ class ModelTesterMixin:
|
||||
|
||||
# We are nuking ALL weights on file, so every parameter should
|
||||
# yell on load. We're going to detect if we yell too much, or too little.
|
||||
with open(os.path.join(tmp_dir, "pytorch_model.bin"), "wb") as f:
|
||||
torch.save({}, f)
|
||||
placeholder_dict = {"tensor": torch.tensor([1, 2])}
|
||||
safe_save_file(placeholder_dict, os.path.join(tmp_dir, "model.safetensors"), metadata={"format": "pt"})
|
||||
model_reloaded, infos = model_class.from_pretrained(tmp_dir, output_loading_info=True)
|
||||
|
||||
prefix = f"{model_reloaded.base_model_prefix}."
|
||||
|
||||
Reference in New Issue
Block a user