Default to msgpack for safetensors (#27460)
* Default to msgpack for safetensors * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -50,7 +50,7 @@ def load_pytorch_checkpoint_in_flax_state_dict(
|
|||||||
"""Load pytorch checkpoints in a flax model"""
|
"""Load pytorch checkpoints in a flax model"""
|
||||||
try:
|
try:
|
||||||
import torch # noqa: F401
|
import torch # noqa: F401
|
||||||
except ImportError:
|
except (ImportError, ModuleNotFoundError):
|
||||||
logger.error(
|
logger.error(
|
||||||
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
|
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
|
||||||
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
|
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
|
||||||
@@ -150,7 +150,7 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
|
|||||||
# numpy currently does not support bfloat16, need to go over float32 in this case to not lose precision
|
# numpy currently does not support bfloat16, need to go over float32 in this case to not lose precision
|
||||||
try:
|
try:
|
||||||
import torch # noqa: F401
|
import torch # noqa: F401
|
||||||
except ImportError:
|
except (ImportError, ModuleNotFoundError):
|
||||||
logger.error(
|
logger.error(
|
||||||
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
|
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
|
||||||
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
|
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
|
||||||
@@ -349,7 +349,7 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
import torch # noqa: F401
|
import torch # noqa: F401
|
||||||
except ImportError:
|
except (ImportError, ModuleNotFoundError):
|
||||||
logger.error(
|
logger.error(
|
||||||
"Loading a Flax weights in PyTorch, requires both PyTorch and Flax to be installed. Please see"
|
"Loading a Flax weights in PyTorch, requires both PyTorch and Flax to be installed. Please see"
|
||||||
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
|
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
|
||||||
|
|||||||
@@ -721,7 +721,14 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||||
is_local = os.path.isdir(pretrained_model_name_or_path)
|
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||||
if os.path.isdir(pretrained_model_name_or_path):
|
if os.path.isdir(pretrained_model_name_or_path):
|
||||||
if is_safetensors_available() and os.path.isfile(
|
if os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)):
|
||||||
|
# Load from a Flax checkpoint
|
||||||
|
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
|
||||||
|
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)):
|
||||||
|
# Load from a sharded Flax checkpoint
|
||||||
|
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)
|
||||||
|
is_sharded = True
|
||||||
|
elif is_safetensors_available() and os.path.isfile(
|
||||||
os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
|
os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
|
||||||
):
|
):
|
||||||
# Load from a safetensors checkpoint
|
# Load from a safetensors checkpoint
|
||||||
@@ -735,13 +742,6 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
# Load from a sharded pytorch checkpoint
|
# Load from a sharded pytorch checkpoint
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)
|
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)
|
||||||
is_sharded = True
|
is_sharded = True
|
||||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)):
|
|
||||||
# Load from a Flax checkpoint
|
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
|
|
||||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)):
|
|
||||||
# Load from a sharded Flax checkpoint
|
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)
|
|
||||||
is_sharded = True
|
|
||||||
# At this stage we don't have a weight file so we will raise an error.
|
# At this stage we don't have a weight file so we will raise an error.
|
||||||
elif is_safetensors_available() and os.path.isfile(
|
elif is_safetensors_available() and os.path.isfile(
|
||||||
os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
|
os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
|
||||||
@@ -770,8 +770,6 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
else:
|
else:
|
||||||
if from_pt:
|
if from_pt:
|
||||||
filename = WEIGHTS_NAME
|
filename = WEIGHTS_NAME
|
||||||
elif is_safetensors_available():
|
|
||||||
filename = SAFE_WEIGHTS_NAME
|
|
||||||
else:
|
else:
|
||||||
filename = FLAX_WEIGHTS_NAME
|
filename = FLAX_WEIGHTS_NAME
|
||||||
|
|
||||||
@@ -792,22 +790,14 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
}
|
}
|
||||||
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
|
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
|
||||||
|
|
||||||
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
|
|
||||||
# result when internet is up, the repo and revision exist, but the file does not.
|
|
||||||
if resolved_archive_file is None and filename == SAFE_WEIGHTS_NAME:
|
|
||||||
# Did not find the safetensors file, let's fallback to Flax.
|
|
||||||
# No support for sharded safetensors yet, so we'll raise an error if that's all we find.
|
|
||||||
filename = FLAX_WEIGHTS_NAME
|
|
||||||
resolved_archive_file = cached_file(
|
|
||||||
pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **cached_file_kwargs
|
|
||||||
)
|
|
||||||
if resolved_archive_file is None and filename == FLAX_WEIGHTS_NAME:
|
|
||||||
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
|
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
|
||||||
|
if resolved_archive_file is None and filename == FLAX_WEIGHTS_NAME:
|
||||||
resolved_archive_file = cached_file(
|
resolved_archive_file = cached_file(
|
||||||
pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME, **cached_file_kwargs
|
pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME, **cached_file_kwargs
|
||||||
)
|
)
|
||||||
if resolved_archive_file is not None:
|
if resolved_archive_file is not None:
|
||||||
is_sharded = True
|
is_sharded = True
|
||||||
|
|
||||||
# Maybe the checkpoint is pytorch sharded, we try to grab the pytorch index name in this case.
|
# Maybe the checkpoint is pytorch sharded, we try to grab the pytorch index name in this case.
|
||||||
if resolved_archive_file is None and from_pt:
|
if resolved_archive_file is None and from_pt:
|
||||||
resolved_archive_file = cached_file(
|
resolved_archive_file = cached_file(
|
||||||
@@ -815,6 +805,17 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
)
|
)
|
||||||
if resolved_archive_file is not None:
|
if resolved_archive_file is not None:
|
||||||
is_sharded = True
|
is_sharded = True
|
||||||
|
|
||||||
|
# If we still haven't found anything, look for `safetensors`.
|
||||||
|
if resolved_archive_file is None:
|
||||||
|
# No support for sharded safetensors yet, so we'll raise an error if that's all we find.
|
||||||
|
filename = SAFE_WEIGHTS_NAME
|
||||||
|
resolved_archive_file = cached_file(
|
||||||
|
pretrained_model_name_or_path, SAFE_WEIGHTS_NAME, **cached_file_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
|
||||||
|
# result when internet is up, the repo and revision exist, but the file does not.
|
||||||
if resolved_archive_file is None:
|
if resolved_archive_file is None:
|
||||||
# Otherwise, maybe there is a TF or Torch model file. We try those to give a helpful error
|
# Otherwise, maybe there is a TF or Torch model file. We try those to give a helpful error
|
||||||
# message.
|
# message.
|
||||||
|
|||||||
@@ -19,8 +19,16 @@ import numpy as np
|
|||||||
from huggingface_hub import HfFolder, delete_repo, snapshot_download
|
from huggingface_hub import HfFolder, delete_repo, snapshot_download
|
||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
|
|
||||||
from transformers import BertConfig, BertModel, is_flax_available
|
from transformers import BertConfig, BertModel, is_flax_available, is_torch_available
|
||||||
from transformers.testing_utils import TOKEN, USER, is_staging_test, require_flax, require_safetensors, require_torch
|
from transformers.testing_utils import (
|
||||||
|
TOKEN,
|
||||||
|
USER,
|
||||||
|
is_pt_flax_cross_test,
|
||||||
|
is_staging_test,
|
||||||
|
require_flax,
|
||||||
|
require_safetensors,
|
||||||
|
require_torch,
|
||||||
|
)
|
||||||
from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME
|
from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME
|
||||||
|
|
||||||
|
|
||||||
@@ -202,6 +210,7 @@ class FlaxModelUtilsTest(unittest.TestCase):
|
|||||||
|
|
||||||
@require_flax
|
@require_flax
|
||||||
@require_torch
|
@require_torch
|
||||||
|
@is_pt_flax_cross_test
|
||||||
def test_safetensors_save_and_load_pt_to_flax(self):
|
def test_safetensors_save_and_load_pt_to_flax(self):
|
||||||
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", from_pt=True)
|
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", from_pt=True)
|
||||||
pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
@@ -218,21 +227,114 @@ class FlaxModelUtilsTest(unittest.TestCase):
|
|||||||
|
|
||||||
@require_safetensors
|
@require_safetensors
|
||||||
def test_safetensors_load_from_hub(self):
|
def test_safetensors_load_from_hub(self):
|
||||||
|
"""
|
||||||
|
This test checks that we can load safetensors from a checkpoint that only has those on the Hub
|
||||||
|
"""
|
||||||
flax_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
flax_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||||
|
|
||||||
# Can load from the Flax-formatted checkpoint
|
# Can load from the Flax-formatted checkpoint
|
||||||
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-safetensors-only")
|
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-safetensors-only")
|
||||||
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_safetensors_load_from_local(self):
|
||||||
|
"""
|
||||||
|
This test checks that we can load safetensors from a checkpoint that only has those on the Hub
|
||||||
|
"""
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
location = snapshot_download("hf-internal-testing/tiny-bert-flax-only", cache_dir=tmp)
|
||||||
|
flax_model = FlaxBertModel.from_pretrained(location)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
location = snapshot_download("hf-internal-testing/tiny-bert-flax-safetensors-only", cache_dir=tmp)
|
||||||
|
safetensors_model = FlaxBertModel.from_pretrained(location)
|
||||||
|
|
||||||
|
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_safetensors
|
@require_safetensors
|
||||||
def test_safetensors_load_from_hub_flax_and_pt(self):
|
@is_pt_flax_cross_test
|
||||||
flax_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
def test_safetensors_load_from_hub_from_safetensors_pt(self):
|
||||||
|
"""
|
||||||
|
This test checks that we can load safetensors from a checkpoint that only has those on the Hub.
|
||||||
|
saved in the "pt" format.
|
||||||
|
"""
|
||||||
|
flax_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-msgpack")
|
||||||
|
|
||||||
# Can load from the PyTorch-formatted checkpoint
|
# Can load from the PyTorch-formatted checkpoint
|
||||||
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only", from_pt=True)
|
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
|
||||||
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_safetensors
|
||||||
|
@is_pt_flax_cross_test
|
||||||
|
def test_safetensors_load_from_local_from_safetensors_pt(self):
|
||||||
|
"""
|
||||||
|
This test checks that we can load safetensors from a checkpoint that only has those on the Hub.
|
||||||
|
saved in the "pt" format.
|
||||||
|
"""
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
location = snapshot_download("hf-internal-testing/tiny-bert-msgpack", cache_dir=tmp)
|
||||||
|
flax_model = FlaxBertModel.from_pretrained(location)
|
||||||
|
|
||||||
|
# Can load from the PyTorch-formatted checkpoint
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp)
|
||||||
|
safetensors_model = FlaxBertModel.from_pretrained(location)
|
||||||
|
|
||||||
|
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_safetensors_load_from_hub_from_safetensors_pt_without_torch_installed(self):
|
||||||
|
"""
|
||||||
|
This test checks that we cannot load safetensors from a checkpoint that only has safetensors
|
||||||
|
saved in the "pt" format if torch isn't installed.
|
||||||
|
"""
|
||||||
|
if is_torch_available():
|
||||||
|
# This test verifies that a correct error message is shown when loading from a pt safetensors
|
||||||
|
# PyTorch shouldn't be installed for this to work correctly.
|
||||||
|
return
|
||||||
|
|
||||||
|
# Cannot load from the PyTorch-formatted checkpoint without PyTorch installed
|
||||||
|
with self.assertRaises(ModuleNotFoundError):
|
||||||
|
_ = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_safetensors_load_from_local_from_safetensors_pt_without_torch_installed(self):
|
||||||
|
"""
|
||||||
|
This test checks that we cannot load safetensors from a checkpoint that only has safetensors
|
||||||
|
saved in the "pt" format if torch isn't installed.
|
||||||
|
"""
|
||||||
|
if is_torch_available():
|
||||||
|
# This test verifies that a correct error message is shown when loading from a pt safetensors
|
||||||
|
# PyTorch shouldn't be installed for this to work correctly.
|
||||||
|
return
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp)
|
||||||
|
|
||||||
|
# Cannot load from the PyTorch-formatted checkpoint without PyTorch installed
|
||||||
|
with self.assertRaises(ModuleNotFoundError):
|
||||||
|
_ = FlaxBertModel.from_pretrained(location)
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_safetensors_load_from_hub_msgpack_before_safetensors(self):
|
||||||
|
"""
|
||||||
|
This test checks that we'll first download msgpack weights before safetensors
|
||||||
|
The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without PyTorch
|
||||||
|
"""
|
||||||
|
FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-msgpack")
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_safetensors_load_from_local_msgpack_before_safetensors(self):
|
||||||
|
"""
|
||||||
|
This test checks that we'll first download msgpack weights before safetensors
|
||||||
|
The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without PyTorch
|
||||||
|
"""
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors-msgpack", cache_dir=tmp)
|
||||||
|
FlaxBertModel.from_pretrained(location)
|
||||||
|
|
||||||
@require_safetensors
|
@require_safetensors
|
||||||
def test_safetensors_flax_from_flax(self):
|
def test_safetensors_flax_from_flax(self):
|
||||||
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||||
|
|||||||
@@ -535,6 +535,71 @@ class TFModelUtilsTest(unittest.TestCase):
|
|||||||
# This should discard the safetensors weights in favor of the .h5 sharded weights
|
# This should discard the safetensors weights in favor of the .h5 sharded weights
|
||||||
TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded")
|
TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded")
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_safetensors_load_from_local(self):
|
||||||
|
"""
|
||||||
|
This test checks that we can load safetensors from a checkpoint that only has those on the Hub
|
||||||
|
"""
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
location = snapshot_download("hf-internal-testing/tiny-bert-tf-only", cache_dir=tmp)
|
||||||
|
tf_model = TFBertModel.from_pretrained(location)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
location = snapshot_download("hf-internal-testing/tiny-bert-tf-safetensors-only", cache_dir=tmp)
|
||||||
|
safetensors_model = TFBertModel.from_pretrained(location)
|
||||||
|
|
||||||
|
for p1, p2 in zip(tf_model.weights, safetensors_model.weights):
|
||||||
|
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_safetensors_load_from_hub_from_safetensors_pt(self):
|
||||||
|
"""
|
||||||
|
This test checks that we can load safetensors from a checkpoint that only has those on the Hub.
|
||||||
|
saved in the "pt" format.
|
||||||
|
"""
|
||||||
|
tf_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-h5")
|
||||||
|
|
||||||
|
# Can load from the PyTorch-formatted checkpoint
|
||||||
|
safetensors_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
|
||||||
|
for p1, p2 in zip(tf_model.weights, safetensors_model.weights):
|
||||||
|
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_safetensors_load_from_local_from_safetensors_pt(self):
|
||||||
|
"""
|
||||||
|
This test checks that we can load safetensors from a local checkpoint that only has those
|
||||||
|
saved in the "pt" format.
|
||||||
|
"""
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
location = snapshot_download("hf-internal-testing/tiny-bert-h5", cache_dir=tmp)
|
||||||
|
tf_model = TFBertModel.from_pretrained(location)
|
||||||
|
|
||||||
|
# Can load from the PyTorch-formatted checkpoint
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp)
|
||||||
|
safetensors_model = TFBertModel.from_pretrained(location)
|
||||||
|
|
||||||
|
for p1, p2 in zip(tf_model.weights, safetensors_model.weights):
|
||||||
|
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_safetensors_load_from_hub_h5_before_safetensors(self):
|
||||||
|
"""
|
||||||
|
This test checks that we'll first download h5 weights before safetensors
|
||||||
|
The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without PyTorch
|
||||||
|
"""
|
||||||
|
TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-msgpack")
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_safetensors_load_from_local_h5_before_safetensors(self):
|
||||||
|
"""
|
||||||
|
This test checks that we'll first download h5 weights before safetensors
|
||||||
|
The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without PyTorch
|
||||||
|
"""
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors-msgpack", cache_dir=tmp)
|
||||||
|
TFBertModel.from_pretrained(location)
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
@is_staging_test
|
@is_staging_test
|
||||||
|
|||||||
Reference in New Issue
Block a user