Enable safetensors conversion from PyTorch to other frameworks without the torch requirement (#27599)
* Initial commit
* Requirements & tests
* Tests
* Tests
* Rogue import
* Rogue torch import
* Cleanup
* Apply suggestions from code review

  Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
* bfloat16 management
* Sanchit's comments
* Import shield
* apply suggestions from code review
* correct bf16
* rebase

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
This commit is contained in:
@@ -23,13 +23,14 @@ from transformers import BertConfig, BertModel, is_flax_available, is_torch_avai
|
||||
from transformers.testing_utils import (
|
||||
TOKEN,
|
||||
USER,
|
||||
CaptureLogger,
|
||||
is_pt_flax_cross_test,
|
||||
is_staging_test,
|
||||
require_flax,
|
||||
require_safetensors,
|
||||
require_torch,
|
||||
)
|
||||
from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME
|
||||
from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME, logging
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
@@ -42,6 +43,9 @@ if is_flax_available():
|
||||
|
||||
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
@require_flax
|
||||
@is_staging_test
|
||||
@@ -251,7 +255,6 @@ class FlaxModelUtilsTest(unittest.TestCase):
|
||||
|
||||
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
||||
|
||||
@require_torch
|
||||
@require_safetensors
|
||||
@is_pt_flax_cross_test
|
||||
def test_safetensors_load_from_hub_from_safetensors_pt(self):
|
||||
@@ -265,7 +268,27 @@ class FlaxModelUtilsTest(unittest.TestCase):
|
||||
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
|
||||
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
||||
|
||||
@require_safetensors
|
||||
@require_torch
|
||||
@is_pt_flax_cross_test
|
||||
def test_safetensors_load_from_hub_from_safetensors_pt_bf16(self):
|
||||
"""
|
||||
This test checks that we can load safetensors from a checkpoint that only has those on the Hub,
|
||||
saved in the "pt" format.
|
||||
"""
|
||||
import torch
|
||||
|
||||
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
|
||||
model.to(torch.bfloat16)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
model.save_pretrained(tmp)
|
||||
flax_model = FlaxBertModel.from_pretrained(tmp)
|
||||
|
||||
# Can load from the PyTorch-formatted checkpoint
|
||||
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")
|
||||
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
||||
|
||||
@require_safetensors
|
||||
@is_pt_flax_cross_test
|
||||
def test_safetensors_load_from_local_from_safetensors_pt(self):
|
||||
@@ -284,39 +307,6 @@ class FlaxModelUtilsTest(unittest.TestCase):
|
||||
|
||||
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_hub_from_safetensors_pt_without_torch_installed(self):
|
||||
"""
|
||||
This test checks that we cannot load safetensors from a checkpoint that only has safetensors
|
||||
saved in the "pt" format if torch isn't installed.
|
||||
"""
|
||||
if is_torch_available():
|
||||
# This test verifies that a correct error message is shown when loading from a pt safetensors
|
||||
# PyTorch shouldn't be installed for this to work correctly.
|
||||
return
|
||||
|
||||
# Cannot load from the PyTorch-formatted checkpoint without PyTorch installed
|
||||
with self.assertRaises(ModuleNotFoundError):
|
||||
_ = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_local_from_safetensors_pt_without_torch_installed(self):
|
||||
"""
|
||||
This test checks that we cannot load safetensors from a checkpoint that only has safetensors
|
||||
saved in the "pt" format if torch isn't installed.
|
||||
"""
|
||||
if is_torch_available():
|
||||
# This test verifies that a correct error message is shown when loading from a pt safetensors
|
||||
# PyTorch shouldn't be installed for this to work correctly.
|
||||
return
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp)
|
||||
|
||||
# Cannot load from the PyTorch-formatted checkpoint without PyTorch installed
|
||||
with self.assertRaises(ModuleNotFoundError):
|
||||
_ = FlaxBertModel.from_pretrained(location)
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_hub_msgpack_before_safetensors(self):
|
||||
"""
|
||||
@@ -347,6 +337,7 @@ class FlaxModelUtilsTest(unittest.TestCase):
|
||||
|
||||
@require_safetensors
|
||||
@require_torch
|
||||
@is_pt_flax_cross_test
|
||||
def test_safetensors_flax_from_torch(self):
|
||||
hub_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
|
||||
@@ -372,3 +363,41 @@ class FlaxModelUtilsTest(unittest.TestCase):
|
||||
# This should not raise even if there are two types of sharded weights
|
||||
# This should discard the safetensors weights in favor of the msgpack sharded weights
|
||||
FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-safetensors-msgpack-sharded")
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_from_pt_bf16(self):
|
||||
# This should not raise; should be able to load bf16-serialized torch safetensors without issue
|
||||
# and without torch.
|
||||
logger = logging.get_logger("transformers.modeling_flax_utils")
|
||||
|
||||
with CaptureLogger(logger) as cl:
|
||||
FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")
|
||||
|
||||
self.assertTrue(
|
||||
"Some of the weights of FlaxBertModel were initialized in bfloat16 precision from the model checkpoint"
|
||||
in cl.out
|
||||
)
|
||||
|
||||
@require_torch
|
||||
@require_safetensors
|
||||
@is_pt_flax_cross_test
|
||||
def test_from_pt_bf16(self):
|
||||
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
|
||||
model.to(torch.bfloat16)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, safe_serialization=False)
|
||||
|
||||
logger = logging.get_logger("transformers.modeling_flax_utils")
|
||||
|
||||
with CaptureLogger(logger) as cl:
|
||||
new_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")
|
||||
|
||||
self.assertTrue(
|
||||
"Some of the weights of FlaxBertModel were initialized in bfloat16 precision from the model checkpoint"
|
||||
in cl.out
|
||||
)
|
||||
|
||||
flat_params_1 = flatten_dict(new_model.params)
|
||||
for value in flat_params_1.values():
|
||||
self.assertEqual(value.dtype, "bfloat16")
|
||||
|
||||
Reference in New Issue
Block a user