[Safetensors] Add explicit flag to from pretrained (#22083)
* [Safetensors] Add explicit flag to from pretrained * add test * remove @ * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
3a35937ede
commit
f780557a34
@@ -2086,6 +2086,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
subfolder = kwargs.pop("subfolder", "")
|
||||
commit_hash = kwargs.pop("_commit_hash", None)
|
||||
variant = kwargs.pop("variant", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
|
||||
|
||||
if trust_remote_code is True:
|
||||
logger.warning(
|
||||
@@ -2222,14 +2223,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
):
|
||||
# Load from a Flax checkpoint in priority if from_flax
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
|
||||
elif is_safetensors_available() and os.path.isfile(
|
||||
elif use_safetensors is not False and os.path.isfile(
|
||||
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant))
|
||||
):
|
||||
# Load from a safetensors checkpoint
|
||||
archive_file = os.path.join(
|
||||
pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)
|
||||
)
|
||||
elif is_safetensors_available() and os.path.isfile(
|
||||
elif use_safetensors is not False and os.path.isfile(
|
||||
os.path.join(
|
||||
pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
|
||||
)
|
||||
@@ -2295,7 +2296,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
filename = TF2_WEIGHTS_NAME
|
||||
elif from_flax:
|
||||
filename = FLAX_WEIGHTS_NAME
|
||||
elif is_safetensors_available():
|
||||
elif use_safetensors is not False:
|
||||
filename = _add_variant(SAFE_WEIGHTS_NAME, variant)
|
||||
else:
|
||||
filename = _add_variant(WEIGHTS_NAME, variant)
|
||||
@@ -2328,6 +2329,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
)
|
||||
if resolved_archive_file is not None:
|
||||
is_sharded = True
|
||||
elif use_safetensors:
|
||||
raise EnvironmentError(
|
||||
f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} and thus cannot be loaded with `safetensors`. Please make sure that the model has been saved with `safe_serialization=True` or do not set `use_safetensors=True`."
|
||||
)
|
||||
else:
|
||||
# This repo has no safetensors file of any kind, we switch to PyTorch.
|
||||
filename = _add_variant(WEIGHTS_NAME, variant)
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
import copy
|
||||
import gc
|
||||
import glob
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
@@ -119,6 +120,7 @@ if is_torch_available():
|
||||
AutoTokenizer,
|
||||
BertConfig,
|
||||
BertModel,
|
||||
CLIPTextModel,
|
||||
PreTrainedModel,
|
||||
T5Config,
|
||||
T5ForConditionalGeneration,
|
||||
@@ -3327,6 +3329,49 @@ class ModelUtilsTest(TestCasePlus):
|
||||
"https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/pytorch_model.bin", config=config
|
||||
)
|
||||
|
||||
@require_safetensors
|
||||
def test_use_safetensors(self):
|
||||
# test nice error message if no safetensor files available
|
||||
with self.assertRaises(OSError) as env_error:
|
||||
AutoModel.from_pretrained("hf-internal-testing/tiny-random-RobertaModel", use_safetensors=True)
|
||||
|
||||
self.assertTrue(
|
||||
"model.safetensors or model.safetensors.index.json and thus cannot be loaded with `safetensors`"
|
||||
in str(env_error.exception)
|
||||
)
|
||||
|
||||
# test that error if only safetensors is available
|
||||
with self.assertRaises(OSError) as env_error:
|
||||
BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors", use_safetensors=False)
|
||||
|
||||
self.assertTrue("does not appear to have a file named pytorch_model.bin" in str(env_error.exception))
|
||||
|
||||
# test that only safetensors if both available and use_safetensors=False
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
CLIPTextModel.from_pretrained(
|
||||
"hf-internal-testing/diffusers-stable-diffusion-tiny-all",
|
||||
subfolder="text_encoder",
|
||||
use_safetensors=False,
|
||||
cache_dir=tmp_dir,
|
||||
)
|
||||
|
||||
all_downloaded_files = glob.glob(os.path.join(tmp_dir, "*", "snapshots", "*", "*", "*"))
|
||||
self.assertTrue(any(f.endswith("bin") for f in all_downloaded_files))
|
||||
self.assertFalse(any(f.endswith("safetensors") for f in all_downloaded_files))
|
||||
|
||||
# test that no safetensors if both available and use_safetensors=True
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
CLIPTextModel.from_pretrained(
|
||||
"hf-internal-testing/diffusers-stable-diffusion-tiny-all",
|
||||
subfolder="text_encoder",
|
||||
use_safetensors=True,
|
||||
cache_dir=tmp_dir,
|
||||
)
|
||||
|
||||
all_downloaded_files = glob.glob(os.path.join(tmp_dir, "*", "snapshots", "*", "*", "*"))
|
||||
self.assertTrue(any(f.endswith("safetensors") for f in all_downloaded_files))
|
||||
self.assertFalse(any(f.endswith("bin") for f in all_downloaded_files))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_save_and_load(self):
|
||||
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
|
||||
Reference in New Issue
Block a user