Add safetensors to model not found error msg for default use_safetensors value (#30602)
* add safetensors to model not found error for default use_safetensors=None case * format code w/ ruff * fix assert true typo
This commit is contained in:
@@ -3270,8 +3270,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"Error no file named {_add_variant(WEIGHTS_NAME, variant)}, {TF2_WEIGHTS_NAME},"
|
f"Error no file named {_add_variant(WEIGHTS_NAME, variant)}, {_add_variant(SAFE_WEIGHTS_NAME, variant)},"
|
||||||
f" {TF_WEIGHTS_NAME + '.index'} or {FLAX_WEIGHTS_NAME} found in directory"
|
f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME + '.index'} or {FLAX_WEIGHTS_NAME} found in directory"
|
||||||
f" {pretrained_model_name_or_path}."
|
f" {pretrained_model_name_or_path}."
|
||||||
)
|
)
|
||||||
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
|
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
|
||||||
@@ -3417,8 +3417,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
else:
|
else:
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
||||||
f" {_add_variant(WEIGHTS_NAME, variant)}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or"
|
f" {_add_variant(WEIGHTS_NAME, variant)}, {_add_variant(SAFE_WEIGHTS_NAME, variant)},"
|
||||||
f" {FLAX_WEIGHTS_NAME}."
|
f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
|
||||||
)
|
)
|
||||||
except EnvironmentError:
|
except EnvironmentError:
|
||||||
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
|
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
|
||||||
|
|||||||
@@ -1001,6 +1001,26 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
self.assertTrue(any(f.endswith("safetensors") for f in all_downloaded_files))
|
self.assertTrue(any(f.endswith("safetensors") for f in all_downloaded_files))
|
||||||
self.assertFalse(any(f.endswith("bin") for f in all_downloaded_files))
|
self.assertFalse(any(f.endswith("bin") for f in all_downloaded_files))
|
||||||
|
|
||||||
|
# test no model file found when use_safetensors=None (default when safetensors package available)
|
||||||
|
with self.assertRaises(OSError) as missing_model_file_error:
|
||||||
|
BertModel.from_pretrained("hf-internal-testing/config-no-model")
|
||||||
|
|
||||||
|
self.assertTrue(
|
||||||
|
"does not appear to have a file named pytorch_model.bin, model.safetensors,"
|
||||||
|
in str(missing_model_file_error.exception)
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.assertRaises(OSError) as missing_model_file_error:
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
with open(os.path.join(tmp_dir, "config.json"), "w") as f:
|
||||||
|
f.write("{}")
|
||||||
|
f.close()
|
||||||
|
BertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
self.assertTrue(
|
||||||
|
"Error no file named pytorch_model.bin, model.safetensors" in str(missing_model_file_error.exception)
|
||||||
|
)
|
||||||
|
|
||||||
@require_safetensors
|
@require_safetensors
|
||||||
def test_safetensors_save_and_load(self):
|
def test_safetensors_save_and_load(self):
|
||||||
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
|||||||
Reference in New Issue
Block a user