Don't default to other weights file when use_safetensors=True (#31874)
* Don't default to other weights file when use_safetensors=True
* Add tests
* Update tests/utils/test_modeling_utils.py
* Add clarifying comments to tests
* Update tests/utils/test_modeling_utils.py
* Update tests/utils/test_modeling_utils.py
This commit is contained in:
@@ -3395,14 +3395,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
|
pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
|
||||||
)
|
)
|
||||||
is_sharded = True
|
is_sharded = True
|
||||||
elif os.path.isfile(
|
elif not use_safetensors and os.path.isfile(
|
||||||
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
|
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
|
||||||
):
|
):
|
||||||
# Load from a PyTorch checkpoint
|
# Load from a PyTorch checkpoint
|
||||||
archive_file = os.path.join(
|
archive_file = os.path.join(
|
||||||
pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)
|
pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)
|
||||||
)
|
)
|
||||||
elif os.path.isfile(
|
elif not use_safetensors and os.path.isfile(
|
||||||
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
|
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
|
||||||
):
|
):
|
||||||
# Load from a sharded PyTorch checkpoint
|
# Load from a sharded PyTorch checkpoint
|
||||||
@@ -3411,15 +3411,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
)
|
)
|
||||||
is_sharded = True
|
is_sharded = True
|
||||||
# At this stage we don't have a weight file so we will raise an error.
|
# At this stage we don't have a weight file so we will raise an error.
|
||||||
elif os.path.isfile(
|
elif not use_safetensors and (
|
||||||
os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
|
os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index"))
|
||||||
) or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)):
|
or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME))
|
||||||
|
):
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
|
f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
|
||||||
f" {pretrained_model_name_or_path} but there is a file for TensorFlow weights. Use"
|
f" {pretrained_model_name_or_path} but there is a file for TensorFlow weights. Use"
|
||||||
" `from_tf=True` to load this model from those weights."
|
" `from_tf=True` to load this model from those weights."
|
||||||
)
|
)
|
||||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)):
|
elif not use_safetensors and os.path.isfile(
|
||||||
|
os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
|
||||||
|
):
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
|
f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
|
||||||
f" {pretrained_model_name_or_path} but there is a file for Flax weights. Use `from_flax=True`"
|
f" {pretrained_model_name_or_path} but there is a file for Flax weights. Use `from_flax=True`"
|
||||||
|
|||||||
@@ -815,6 +815,72 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||||
self.assertTrue(torch.allclose(p1, p2))
|
self.assertTrue(torch.allclose(p1, p2))
|
||||||
|
|
||||||
|
def test_checkpoint_loading_only_safetensors_available(self):
|
||||||
|
# Test that the loading behaviour is as expected when only safetensors checkpoints are available
|
||||||
|
# - We can load the model with use_safetensors=True
|
||||||
|
# - We can load the model without specifying use_safetensors i.e. we search for the available checkpoint,
|
||||||
|
# preferring safetensors
|
||||||
|
# - We cannot load the model with use_safetensors=False
|
||||||
|
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir, max_shard_size="50kB", safe_serialization=True)
|
||||||
|
|
||||||
|
weights_index_name = ".".join(SAFE_WEIGHTS_INDEX_NAME.split(".")[:-1] + ["json"])
|
||||||
|
weights_index_file = os.path.join(tmp_dir, weights_index_name)
|
||||||
|
self.assertTrue(os.path.isfile(weights_index_file))
|
||||||
|
|
||||||
|
for i in range(1, 5):
|
||||||
|
weights_name = f"model-0000{i}-of-00005" + ".safetensors"
|
||||||
|
weights_name_file = os.path.join(tmp_dir, weights_name)
|
||||||
|
self.assertTrue(os.path.isfile(weights_name_file))
|
||||||
|
|
||||||
|
# Setting use_safetensors=False should raise an error as the checkpoint was saved with safe_serialization=True
|
||||||
|
with self.assertRaises(OSError):
|
||||||
|
_ = BertModel.from_pretrained(tmp_dir, use_safetensors=False)
|
||||||
|
|
||||||
|
# We can load the model with use_safetensors=True
|
||||||
|
new_model = BertModel.from_pretrained(tmp_dir, use_safetensors=True)
|
||||||
|
|
||||||
|
# We can load the model without specifying use_safetensors
|
||||||
|
new_model = BertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||||
|
self.assertTrue(torch.allclose(p1, p2))
|
||||||
|
|
||||||
|
def test_checkpoint_loading_only_pytorch_bin_available(self):
|
||||||
|
# Test that the loading behaviour is as expected when only pytorch checkpoints are available
|
||||||
|
# - We can load the model with use_safetensors=False
|
||||||
|
# - We can load the model without specifying use_safetensors i.e. we search for the available checkpoint,
|
||||||
|
# preferring safetensors but falling back to pytorch
|
||||||
|
# - We cannot load the model with use_safetensors=True
|
||||||
|
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir, max_shard_size="50kB", safe_serialization=False)
|
||||||
|
|
||||||
|
weights_index_name = ".".join(WEIGHTS_INDEX_NAME.split(".")[:-1] + ["json"])
|
||||||
|
weights_index_file = os.path.join(tmp_dir, weights_index_name)
|
||||||
|
self.assertTrue(os.path.isfile(weights_index_file))
|
||||||
|
|
||||||
|
for i in range(1, 5):
|
||||||
|
weights_name = WEIGHTS_NAME.split(".")[0].split("_")[0] + f"_model-0000{i}-of-00005" + ".bin"
|
||||||
|
weights_name_file = os.path.join(tmp_dir, weights_name)
|
||||||
|
self.assertTrue(os.path.isfile(weights_name_file))
|
||||||
|
|
||||||
|
# Setting use_safetensors=True should raise an error as the checkpoint was saved with safe_serialization=False
|
||||||
|
with self.assertRaises(OSError):
|
||||||
|
_ = BertModel.from_pretrained(tmp_dir, use_safetensors=True)
|
||||||
|
|
||||||
|
# We can load the model with use_safetensors=False
|
||||||
|
new_model = BertModel.from_pretrained(tmp_dir, use_safetensors=False)
|
||||||
|
|
||||||
|
# We can load the model without specifying use_safetensors
|
||||||
|
new_model = BertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||||
|
self.assertTrue(torch.allclose(p1, p2))
|
||||||
|
|
||||||
def test_checkpoint_variant_hub(self):
|
def test_checkpoint_variant_hub(self):
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
with self.assertRaises(EnvironmentError):
|
with self.assertRaises(EnvironmentError):
|
||||||
|
|||||||
Reference in New Issue
Block a user