Don't default to other weights file when use_safetensors=True (#31874)
* Don't default to other weights file when use_safetensors=True
* Add tests
* Update tests/utils/test_modeling_utils.py
* Add clarifying comments to tests
* Update tests/utils/test_modeling_utils.py
* Update tests/utils/test_modeling_utils.py
This commit is contained in:
@@ -3395,14 +3395,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
|
||||
)
|
||||
is_sharded = True
|
||||
elif os.path.isfile(
|
||||
elif not use_safetensors and os.path.isfile(
|
||||
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
|
||||
):
|
||||
# Load from a PyTorch checkpoint
|
||||
archive_file = os.path.join(
|
||||
pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)
|
||||
)
|
||||
elif os.path.isfile(
|
||||
elif not use_safetensors and os.path.isfile(
|
||||
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
|
||||
):
|
||||
# Load from a sharded PyTorch checkpoint
|
||||
@@ -3411,15 +3411,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
)
|
||||
is_sharded = True
|
||||
# At this stage we don't have a weight file so we will raise an error.
|
||||
elif os.path.isfile(
|
||||
os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
|
||||
) or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)):
|
||||
elif not use_safetensors and (
|
||||
os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index"))
|
||||
or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME))
|
||||
):
|
||||
raise EnvironmentError(
|
||||
f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
|
||||
f" {pretrained_model_name_or_path} but there is a file for TensorFlow weights. Use"
|
||||
" `from_tf=True` to load this model from those weights."
|
||||
)
|
||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)):
|
||||
elif not use_safetensors and os.path.isfile(
|
||||
os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
|
||||
):
|
||||
raise EnvironmentError(
|
||||
f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
|
||||
f" {pretrained_model_name_or_path} but there is a file for Flax weights. Use `from_flax=True`"
|
||||
|
||||
@@ -815,6 +815,72 @@ class ModelUtilsTest(TestCasePlus):
|
||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||
self.assertTrue(torch.allclose(p1, p2))
|
||||
|
||||
def test_checkpoint_loading_only_safetensors_available(self):
    """Check checkpoint loading when only sharded safetensors weights are saved.

    Expected behaviour:
    - ``use_safetensors=True`` loads the model.
    - Leaving ``use_safetensors`` unset loads the model, i.e. we search for the
      available checkpoint, preferring safetensors.
    - ``use_safetensors=False`` raises ``OSError`` because the checkpoint was
      saved with ``safe_serialization=True``.
    """
    model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(tmp_dir, max_shard_size="50kB", safe_serialization=True)

        weights_index_name = ".".join(SAFE_WEIGHTS_INDEX_NAME.split(".")[:-1] + ["json"])
        weights_index_file = os.path.join(tmp_dir, weights_index_name)
        self.assertTrue(os.path.isfile(weights_index_file))

        # The shard names say "of-00005", so check all five shards (1..5 inclusive).
        for i in range(1, 6):
            weights_name = f"model-0000{i}-of-00005" + ".safetensors"
            weights_name_file = os.path.join(tmp_dir, weights_name)
            self.assertTrue(os.path.isfile(weights_name_file))

        # Setting use_safetensors=False should raise an error as the checkpoint was saved with safetensors=True
        with self.assertRaises(OSError):
            _ = BertModel.from_pretrained(tmp_dir, use_safetensors=False)

        # We can load the model with use_safetensors=True
        explicit_model = BertModel.from_pretrained(tmp_dir, use_safetensors=True)

        # We can load the model without specifying use_safetensors
        default_model = BertModel.from_pretrained(tmp_dir)

        # Verify the weights from BOTH successful load paths, not only the last one.
        for new_model in (explicit_model, default_model):
            for p1, p2 in zip(model.parameters(), new_model.parameters()):
                self.assertTrue(torch.allclose(p1, p2))
|
||||
|
||||
def test_checkpoint_loading_only_pytorch_bin_available(self):
    """Check checkpoint loading when only sharded PyTorch ``.bin`` weights are saved.

    Expected behaviour:
    - ``use_safetensors=False`` loads the model.
    - Leaving ``use_safetensors`` unset loads the model, i.e. we search for the
      available checkpoint, preferring safetensors but falling back to pytorch.
    - ``use_safetensors=True`` raises ``OSError`` because the checkpoint was
      saved with ``safe_serialization=False``.
    """
    model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(tmp_dir, max_shard_size="50kB", safe_serialization=False)

        weights_index_name = ".".join(WEIGHTS_INDEX_NAME.split(".")[:-1] + ["json"])
        weights_index_file = os.path.join(tmp_dir, weights_index_name)
        self.assertTrue(os.path.isfile(weights_index_file))

        # The shard names say "of-00005", so check all five shards (1..5 inclusive).
        for i in range(1, 6):
            weights_name = WEIGHTS_NAME.split(".")[0].split("_")[0] + f"_model-0000{i}-of-00005" + ".bin"
            weights_name_file = os.path.join(tmp_dir, weights_name)
            self.assertTrue(os.path.isfile(weights_name_file))

        # Setting use_safetensors=True should raise an error as the checkpoint was saved with safetensors=False
        with self.assertRaises(OSError):
            _ = BertModel.from_pretrained(tmp_dir, use_safetensors=True)

        # We can load the model with use_safetensors=False
        explicit_model = BertModel.from_pretrained(tmp_dir, use_safetensors=False)

        # We can load the model without specifying use_safetensors
        default_model = BertModel.from_pretrained(tmp_dir)

        # Verify the weights from BOTH successful load paths, not only the last one.
        for new_model in (explicit_model, default_model):
            for p1, p2 in zip(model.parameters(), new_model.parameters()):
                self.assertTrue(torch.allclose(p1, p2))
|
||||
|
||||
def test_checkpoint_variant_hub(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
with self.assertRaises(EnvironmentError):
|
||||
|
||||
Reference in New Issue
Block a user