[Safetensors] Add explicit flag to from pretrained (#22083)
* [Safetensors] Add explicit flag to from pretrained * add test * remove @ * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
3a35937ede
commit
f780557a34
@@ -15,6 +15,7 @@
|
||||
|
||||
import copy
|
||||
import gc
|
||||
import glob
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
@@ -119,6 +120,7 @@ if is_torch_available():
|
||||
AutoTokenizer,
|
||||
BertConfig,
|
||||
BertModel,
|
||||
CLIPTextModel,
|
||||
PreTrainedModel,
|
||||
T5Config,
|
||||
T5ForConditionalGeneration,
|
||||
@@ -3327,6 +3329,49 @@ class ModelUtilsTest(TestCasePlus):
|
||||
"https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/pytorch_model.bin", config=config
|
||||
)
|
||||
|
||||
@require_safetensors
|
||||
def test_use_safetensors(self):
|
||||
# test nice error message if no safetensor files available
|
||||
with self.assertRaises(OSError) as env_error:
|
||||
AutoModel.from_pretrained("hf-internal-testing/tiny-random-RobertaModel", use_safetensors=True)
|
||||
|
||||
self.assertTrue(
|
||||
"model.safetensors or model.safetensors.index.json and thus cannot be loaded with `safetensors`"
|
||||
in str(env_error.exception)
|
||||
)
|
||||
|
||||
# test that error if only safetensors is available
|
||||
with self.assertRaises(OSError) as env_error:
|
||||
BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors", use_safetensors=False)
|
||||
|
||||
self.assertTrue("does not appear to have a file named pytorch_model.bin" in str(env_error.exception))
|
||||
|
||||
# test that only safetensors if both available and use_safetensors=False
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
CLIPTextModel.from_pretrained(
|
||||
"hf-internal-testing/diffusers-stable-diffusion-tiny-all",
|
||||
subfolder="text_encoder",
|
||||
use_safetensors=False,
|
||||
cache_dir=tmp_dir,
|
||||
)
|
||||
|
||||
all_downloaded_files = glob.glob(os.path.join(tmp_dir, "*", "snapshots", "*", "*", "*"))
|
||||
self.assertTrue(any(f.endswith("bin") for f in all_downloaded_files))
|
||||
self.assertFalse(any(f.endswith("safetensors") for f in all_downloaded_files))
|
||||
|
||||
# test that no safetensors if both available and use_safetensors=True
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
CLIPTextModel.from_pretrained(
|
||||
"hf-internal-testing/diffusers-stable-diffusion-tiny-all",
|
||||
subfolder="text_encoder",
|
||||
use_safetensors=True,
|
||||
cache_dir=tmp_dir,
|
||||
)
|
||||
|
||||
all_downloaded_files = glob.glob(os.path.join(tmp_dir, "*", "snapshots", "*", "*", "*"))
|
||||
self.assertTrue(any(f.endswith("safetensors") for f in all_downloaded_files))
|
||||
self.assertFalse(any(f.endswith("bin") for f in all_downloaded_files))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_save_and_load(self):
|
||||
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
|
||||
Reference in New Issue
Block a user