[Safetensors] Add explicit flag to from pretrained (#22083)

* [Safetensors] Add explicit  flag to from pretrained

* add test

* remove @

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2023-03-13 21:39:06 +01:00
committed by GitHub
parent 3a35937ede
commit f780557a34
2 changed files with 53 additions and 3 deletions

View File

@@ -2086,6 +2086,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
if trust_remote_code is True:
logger.warning(
@@ -2222,14 +2223,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
):
# Load from a Flax checkpoint in priority if from_flax
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
elif is_safetensors_available() and os.path.isfile(
elif use_safetensors is not False and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant))
):
# Load from a safetensors checkpoint
archive_file = os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)
)
elif is_safetensors_available() and os.path.isfile(
elif use_safetensors is not False and os.path.isfile(
os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
)
@@ -2295,7 +2296,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
filename = TF2_WEIGHTS_NAME
elif from_flax:
filename = FLAX_WEIGHTS_NAME
elif is_safetensors_available():
elif use_safetensors is not False:
filename = _add_variant(SAFE_WEIGHTS_NAME, variant)
else:
filename = _add_variant(WEIGHTS_NAME, variant)
@@ -2328,6 +2329,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
if resolved_archive_file is not None:
is_sharded = True
elif use_safetensors:
raise EnvironmentError(
f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} and thus cannot be loaded with `safetensors`. Please make sure that the model has been saved with `safe_serialization=True` or do not set `use_safetensors=True`."
)
else:
# This repo has no safetensors file of any kind, we switch to PyTorch.
filename = _add_variant(WEIGHTS_NAME, variant)

View File

@@ -15,6 +15,7 @@
import copy
import gc
import glob
import inspect
import json
import os
@@ -119,6 +120,7 @@ if is_torch_available():
AutoTokenizer,
BertConfig,
BertModel,
CLIPTextModel,
PreTrainedModel,
T5Config,
T5ForConditionalGeneration,
@@ -3327,6 +3329,49 @@ class ModelUtilsTest(TestCasePlus):
"https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/pytorch_model.bin", config=config
)
@require_safetensors
def test_use_safetensors(self):
# test nice error message if no safetensor files available
with self.assertRaises(OSError) as env_error:
AutoModel.from_pretrained("hf-internal-testing/tiny-random-RobertaModel", use_safetensors=True)
self.assertTrue(
"model.safetensors or model.safetensors.index.json and thus cannot be loaded with `safetensors`"
in str(env_error.exception)
)
# test that error if only safetensors is available
with self.assertRaises(OSError) as env_error:
BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors", use_safetensors=False)
self.assertTrue("does not appear to have a file named pytorch_model.bin" in str(env_error.exception))
# test that only safetensors if both available and use_safetensors=False
with tempfile.TemporaryDirectory() as tmp_dir:
CLIPTextModel.from_pretrained(
"hf-internal-testing/diffusers-stable-diffusion-tiny-all",
subfolder="text_encoder",
use_safetensors=False,
cache_dir=tmp_dir,
)
all_downloaded_files = glob.glob(os.path.join(tmp_dir, "*", "snapshots", "*", "*", "*"))
self.assertTrue(any(f.endswith("bin") for f in all_downloaded_files))
self.assertFalse(any(f.endswith("safetensors") for f in all_downloaded_files))
# test that no safetensors if both available and use_safetensors=True
with tempfile.TemporaryDirectory() as tmp_dir:
CLIPTextModel.from_pretrained(
"hf-internal-testing/diffusers-stable-diffusion-tiny-all",
subfolder="text_encoder",
use_safetensors=True,
cache_dir=tmp_dir,
)
all_downloaded_files = glob.glob(os.path.join(tmp_dir, "*", "snapshots", "*", "*", "*"))
self.assertTrue(any(f.endswith("safetensors") for f in all_downloaded_files))
self.assertFalse(any(f.endswith("bin") for f in all_downloaded_files))
@require_safetensors
def test_safetensors_save_and_load(self):
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")