Don't default to other weights file when use_safetensors=True (#31874)
* Don't default to other weights file when use_safetensors=True
* Add tests
* Update tests/utils/test_modeling_utils.py
* Add clarifying comments to tests
* Update tests/utils/test_modeling_utils.py
* Update tests/utils/test_modeling_utils.py
This commit is contained in:
@@ -815,6 +815,72 @@ class ModelUtilsTest(TestCasePlus):
|
||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||
self.assertTrue(torch.allclose(p1, p2))
|
||||
|
||||
def test_checkpoint_loading_only_safetensors_available(self):
    """Check checkpoint loading when only safetensors shards exist on disk.

    Expected behaviour:
    - Loading with `use_safetensors=True` succeeds.
    - Loading without specifying `use_safetensors` succeeds, i.e. the
      available checkpoint is searched for, preferring safetensors.
    - Loading with `use_safetensors=False` raises `OSError`.
    """
    model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(tmp_dir, max_shard_size="50kB", safe_serialization=True)

        # The index constant already carries its ".json" suffix, so it is used as-is.
        weights_index_file = os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)
        self.assertTrue(os.path.isfile(weights_index_file))

        # The checkpoint is split into 5 shards; verify every one of them,
        # including the last (range end is exclusive).
        for i in range(1, 6):
            weights_name = f"model-0000{i}-of-00005" + ".safetensors"
            weights_name_file = os.path.join(tmp_dir, weights_name)
            self.assertTrue(os.path.isfile(weights_name_file))

        # Setting use_safetensors=False should raise an error as the checkpoint was saved with safetensors=True
        with self.assertRaises(OSError):
            _ = BertModel.from_pretrained(tmp_dir, use_safetensors=False)

        # Both supported ways of loading must succeed AND reproduce the saved
        # weights: explicitly with use_safetensors=True, and without specifying
        # use_safetensors at all. Each loaded model is compared separately so
        # neither result is silently discarded.
        for load_kwargs in ({"use_safetensors": True}, {}):
            with self.subTest(load_kwargs=load_kwargs):
                new_model = BertModel.from_pretrained(tmp_dir, **load_kwargs)

                for p1, p2 in zip(model.parameters(), new_model.parameters()):
                    self.assertTrue(torch.allclose(p1, p2))
|
||||
|
||||
def test_checkpoint_loading_only_pytorch_bin_available(self):
    """Check checkpoint loading when only PyTorch `.bin` shards exist on disk.

    Expected behaviour:
    - Loading with `use_safetensors=False` succeeds.
    - Loading without specifying `use_safetensors` succeeds, i.e. the
      available checkpoint is searched for, preferring safetensors but
      falling back to pytorch.
    - Loading with `use_safetensors=True` raises `OSError`.
    """
    model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(tmp_dir, max_shard_size="50kB", safe_serialization=False)

        # The index constant already carries its ".json" suffix, so it is used as-is.
        weights_index_file = os.path.join(tmp_dir, WEIGHTS_INDEX_NAME)
        self.assertTrue(os.path.isfile(weights_index_file))

        # Shard files are named "<prefix>_model-0000i-of-00005.bin", where the
        # prefix is the part of WEIGHTS_NAME before its first underscore.
        shard_prefix = WEIGHTS_NAME.split(".")[0].split("_")[0]

        # The checkpoint is split into 5 shards; verify every one of them,
        # including the last (range end is exclusive).
        for i in range(1, 6):
            weights_name = shard_prefix + f"_model-0000{i}-of-00005" + ".bin"
            weights_name_file = os.path.join(tmp_dir, weights_name)
            self.assertTrue(os.path.isfile(weights_name_file))

        # Setting use_safetensors=True should raise an error as the checkpoint was saved with safetensors=False
        with self.assertRaises(OSError):
            _ = BertModel.from_pretrained(tmp_dir, use_safetensors=True)

        # Both supported ways of loading must succeed AND reproduce the saved
        # weights: explicitly with use_safetensors=False, and without specifying
        # use_safetensors at all. Each loaded model is compared separately so
        # neither result is silently discarded.
        for load_kwargs in ({"use_safetensors": False}, {}):
            with self.subTest(load_kwargs=load_kwargs):
                new_model = BertModel.from_pretrained(tmp_dir, **load_kwargs)

                for p1, p2 in zip(model.parameters(), new_model.parameters()):
                    self.assertTrue(torch.allclose(p1, p2))
|
||||
|
||||
def test_checkpoint_variant_hub(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
with self.assertRaises(EnvironmentError):
|
||||
|
||||
Reference in New Issue
Block a user