Enable safetensors conversion from PyTorch to other frameworks without the torch requirement (#27599)
* Initial commit
* Requirements & tests
* Tests
* Tests
* Rogue import
* Rogue torch import
* Cleanup
* Apply suggestions from code review
  Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
* bfloat16 management
* Sanchit's comments
* Import shield
* apply suggestions from code review
* correct bf16
* rebase

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
This commit is contained in:
2
setup.py
2
setup.py
@@ -158,7 +158,7 @@ _deps = [
|
|||||||
"ruff==0.1.5",
|
"ruff==0.1.5",
|
||||||
"sacrebleu>=1.4.12,<2.0.0",
|
"sacrebleu>=1.4.12,<2.0.0",
|
||||||
"sacremoses",
|
"sacremoses",
|
||||||
"safetensors>=0.3.1",
|
"safetensors>=0.4.1",
|
||||||
"sagemaker>=2.31.0",
|
"sagemaker>=2.31.0",
|
||||||
"scikit-learn",
|
"scikit-learn",
|
||||||
"sentencepiece>=0.1.91,!=0.1.92",
|
"sentencepiece>=0.1.91,!=0.1.92",
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ deps = {
|
|||||||
"ruff": "ruff==0.1.5",
|
"ruff": "ruff==0.1.5",
|
||||||
"sacrebleu": "sacrebleu>=1.4.12,<2.0.0",
|
"sacrebleu": "sacrebleu>=1.4.12,<2.0.0",
|
||||||
"sacremoses": "sacremoses",
|
"sacremoses": "sacremoses",
|
||||||
"safetensors": "safetensors>=0.3.1",
|
"safetensors": "safetensors>=0.4.1",
|
||||||
"sagemaker": "sagemaker>=2.31.0",
|
"sagemaker": "sagemaker>=2.31.0",
|
||||||
"scikit-learn": "scikit-learn",
|
"scikit-learn": "scikit-learn",
|
||||||
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
|
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
|
||||||
|
|||||||
@@ -27,10 +27,13 @@ from flax.traverse_util import flatten_dict, unflatten_dict
|
|||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
|
|
||||||
from . import is_safetensors_available
|
from . import is_safetensors_available, is_torch_available
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
if is_safetensors_available():
|
if is_safetensors_available():
|
||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
from safetensors.flax import load_file as safe_load_file
|
from safetensors.flax import load_file as safe_load_file
|
||||||
@@ -48,6 +51,17 @@ def load_pytorch_checkpoint_in_flax_state_dict(
|
|||||||
flax_model, pytorch_checkpoint_path, is_sharded, allow_missing_keys=False
|
flax_model, pytorch_checkpoint_path, is_sharded, allow_missing_keys=False
|
||||||
):
|
):
|
||||||
"""Load pytorch checkpoints in a flax model"""
|
"""Load pytorch checkpoints in a flax model"""
|
||||||
|
|
||||||
|
if not is_sharded:
|
||||||
|
pt_path = os.path.abspath(pytorch_checkpoint_path)
|
||||||
|
logger.info(f"Loading PyTorch weights from {pt_path}")
|
||||||
|
|
||||||
|
if pt_path.endswith(".safetensors"):
|
||||||
|
pt_state_dict = {}
|
||||||
|
with safe_open(pt_path, framework="flax") as f:
|
||||||
|
for k in f.keys():
|
||||||
|
pt_state_dict[k] = f.get_tensor(k)
|
||||||
|
else:
|
||||||
try:
|
try:
|
||||||
import torch # noqa: F401
|
import torch # noqa: F401
|
||||||
|
|
||||||
@@ -60,16 +74,6 @@ def load_pytorch_checkpoint_in_flax_state_dict(
|
|||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
if not is_sharded:
|
|
||||||
pt_path = os.path.abspath(pytorch_checkpoint_path)
|
|
||||||
logger.info(f"Loading PyTorch weights from {pt_path}")
|
|
||||||
|
|
||||||
if pt_path.endswith(".safetensors"):
|
|
||||||
pt_state_dict = {}
|
|
||||||
with safe_open(pt_path, framework="pt") as f:
|
|
||||||
for k in f.keys():
|
|
||||||
pt_state_dict[k] = f.get_tensor(k)
|
|
||||||
else:
|
|
||||||
pt_state_dict = torch.load(pt_path, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13)
|
pt_state_dict = torch.load(pt_path, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13)
|
||||||
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
|
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
|
||||||
|
|
||||||
@@ -149,21 +153,17 @@ def rename_key_and_reshape_tensor(
|
|||||||
|
|
||||||
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
|
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
|
||||||
# convert pytorch tensor to numpy
|
# convert pytorch tensor to numpy
|
||||||
# numpy currently does not support bfloat16, need to go over float32 in this case to not lose precision
|
from_bin = is_torch_available() and isinstance(next(iter(pt_state_dict.values())), torch.Tensor)
|
||||||
try:
|
bfloat16 = torch.bfloat16 if from_bin else "bfloat16"
|
||||||
import torch # noqa: F401
|
|
||||||
except (ImportError, ModuleNotFoundError):
|
|
||||||
logger.error(
|
|
||||||
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
|
|
||||||
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
|
|
||||||
" instructions."
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}
|
weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}
|
||||||
pt_state_dict = {
|
|
||||||
k: v.numpy() if not v.dtype == torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items()
|
if from_bin:
|
||||||
}
|
for k, v in pt_state_dict.items():
|
||||||
|
# numpy currently does not support bfloat16, need to go over float32 in this case to not lose precision
|
||||||
|
if v.dtype == bfloat16:
|
||||||
|
v = v.float()
|
||||||
|
pt_state_dict[k] = v.numpy()
|
||||||
|
|
||||||
model_prefix = flax_model.base_model_prefix
|
model_prefix = flax_model.base_model_prefix
|
||||||
|
|
||||||
@@ -191,7 +191,7 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
|
|||||||
# Need to change some parameters name to match Flax names
|
# Need to change some parameters name to match Flax names
|
||||||
for pt_key, pt_tensor in pt_state_dict.items():
|
for pt_key, pt_tensor in pt_state_dict.items():
|
||||||
pt_tuple_key = tuple(pt_key.split("."))
|
pt_tuple_key = tuple(pt_key.split("."))
|
||||||
is_bfloat_16 = weight_dtypes[pt_key] == torch.bfloat16
|
is_bfloat_16 = weight_dtypes[pt_key] == bfloat16
|
||||||
|
|
||||||
# remove base model prefix if necessary
|
# remove base model prefix if necessary
|
||||||
has_base_model_prefix = pt_tuple_key[0] == model_prefix
|
has_base_model_prefix = pt_tuple_key[0] == model_prefix
|
||||||
@@ -229,7 +229,6 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
|
|||||||
flax_state_dict[("params",) + flax_key] = (
|
flax_state_dict[("params",) + flax_key] = (
|
||||||
jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
|
jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# also add unexpected weight so that warning is thrown
|
# also add unexpected weight so that warning is thrown
|
||||||
flax_state_dict[flax_key] = (
|
flax_state_dict[flax_key] = (
|
||||||
|
|||||||
@@ -23,13 +23,14 @@ from transformers import BertConfig, BertModel, is_flax_available, is_torch_avai
|
|||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
TOKEN,
|
TOKEN,
|
||||||
USER,
|
USER,
|
||||||
|
CaptureLogger,
|
||||||
is_pt_flax_cross_test,
|
is_pt_flax_cross_test,
|
||||||
is_staging_test,
|
is_staging_test,
|
||||||
require_flax,
|
require_flax,
|
||||||
require_safetensors,
|
require_safetensors,
|
||||||
require_torch,
|
require_torch,
|
||||||
)
|
)
|
||||||
from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME
|
from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME, logging
|
||||||
|
|
||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
@@ -42,6 +43,9 @@ if is_flax_available():
|
|||||||
|
|
||||||
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
|
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
@require_flax
|
@require_flax
|
||||||
@is_staging_test
|
@is_staging_test
|
||||||
@@ -251,7 +255,6 @@ class FlaxModelUtilsTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
||||||
|
|
||||||
@require_torch
|
|
||||||
@require_safetensors
|
@require_safetensors
|
||||||
@is_pt_flax_cross_test
|
@is_pt_flax_cross_test
|
||||||
def test_safetensors_load_from_hub_from_safetensors_pt(self):
|
def test_safetensors_load_from_hub_from_safetensors_pt(self):
|
||||||
@@ -265,7 +268,27 @@ class FlaxModelUtilsTest(unittest.TestCase):
|
|||||||
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
|
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
|
||||||
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
@require_torch
|
@require_torch
|
||||||
|
@is_pt_flax_cross_test
|
||||||
|
def test_safetensors_load_from_hub_from_safetensors_pt_bf16(self):
|
||||||
|
"""
|
||||||
|
This test checks that we can load safetensors from a checkpoint that only has those on the Hub,
|
||||||
|
saved in the "pt" format.
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
|
||||||
|
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
|
||||||
|
model.to(torch.bfloat16)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
model.save_pretrained(tmp)
|
||||||
|
flax_model = FlaxBertModel.from_pretrained(tmp)
|
||||||
|
|
||||||
|
# Can load from the PyTorch-formatted checkpoint
|
||||||
|
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")
|
||||||
|
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
||||||
|
|
||||||
@require_safetensors
|
@require_safetensors
|
||||||
@is_pt_flax_cross_test
|
@is_pt_flax_cross_test
|
||||||
def test_safetensors_load_from_local_from_safetensors_pt(self):
|
def test_safetensors_load_from_local_from_safetensors_pt(self):
|
||||||
@@ -284,39 +307,6 @@ class FlaxModelUtilsTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
||||||
|
|
||||||
@require_safetensors
|
|
||||||
def test_safetensors_load_from_hub_from_safetensors_pt_without_torch_installed(self):
|
|
||||||
"""
|
|
||||||
This test checks that we cannot load safetensors from a checkpoint that only has safetensors
|
|
||||||
saved in the "pt" format if torch isn't installed.
|
|
||||||
"""
|
|
||||||
if is_torch_available():
|
|
||||||
# This test verifies that a correct error message is shown when loading from a pt safetensors
|
|
||||||
# PyTorch shouldn't be installed for this to work correctly.
|
|
||||||
return
|
|
||||||
|
|
||||||
# Cannot load from the PyTorch-formatted checkpoint without PyTorch installed
|
|
||||||
with self.assertRaises(ModuleNotFoundError):
|
|
||||||
_ = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
|
|
||||||
|
|
||||||
@require_safetensors
|
|
||||||
def test_safetensors_load_from_local_from_safetensors_pt_without_torch_installed(self):
|
|
||||||
"""
|
|
||||||
This test checks that we cannot load safetensors from a checkpoint that only has safetensors
|
|
||||||
saved in the "pt" format if torch isn't installed.
|
|
||||||
"""
|
|
||||||
if is_torch_available():
|
|
||||||
# This test verifies that a correct error message is shown when loading from a pt safetensors
|
|
||||||
# PyTorch shouldn't be installed for this to work correctly.
|
|
||||||
return
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp:
|
|
||||||
location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp)
|
|
||||||
|
|
||||||
# Cannot load from the PyTorch-formatted checkpoint without PyTorch installed
|
|
||||||
with self.assertRaises(ModuleNotFoundError):
|
|
||||||
_ = FlaxBertModel.from_pretrained(location)
|
|
||||||
|
|
||||||
@require_safetensors
|
@require_safetensors
|
||||||
def test_safetensors_load_from_hub_msgpack_before_safetensors(self):
|
def test_safetensors_load_from_hub_msgpack_before_safetensors(self):
|
||||||
"""
|
"""
|
||||||
@@ -347,6 +337,7 @@ class FlaxModelUtilsTest(unittest.TestCase):
|
|||||||
|
|
||||||
@require_safetensors
|
@require_safetensors
|
||||||
@require_torch
|
@require_torch
|
||||||
|
@is_pt_flax_cross_test
|
||||||
def test_safetensors_flax_from_torch(self):
|
def test_safetensors_flax_from_torch(self):
|
||||||
hub_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
hub_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||||
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
|
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
|
||||||
@@ -372,3 +363,41 @@ class FlaxModelUtilsTest(unittest.TestCase):
|
|||||||
# This should not raise even if there are two types of sharded weights
|
# This should not raise even if there are two types of sharded weights
|
||||||
# This should discard the safetensors weights in favor of the msgpack sharded weights
|
# This should discard the safetensors weights in favor of the msgpack sharded weights
|
||||||
FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-safetensors-msgpack-sharded")
|
FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-safetensors-msgpack-sharded")
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_safetensors_from_pt_bf16(self):
|
||||||
|
# This should not raise; should be able to load bf16-serialized torch safetensors without issue
|
||||||
|
# and without torch.
|
||||||
|
logger = logging.get_logger("transformers.modeling_flax_utils")
|
||||||
|
|
||||||
|
with CaptureLogger(logger) as cl:
|
||||||
|
FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")
|
||||||
|
|
||||||
|
self.assertTrue(
|
||||||
|
"Some of the weights of FlaxBertModel were initialized in bfloat16 precision from the model checkpoint"
|
||||||
|
in cl.out
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_safetensors
|
||||||
|
@is_pt_flax_cross_test
|
||||||
|
def test_from_pt_bf16(self):
|
||||||
|
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
|
||||||
|
model.to(torch.bfloat16)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir, safe_serialization=False)
|
||||||
|
|
||||||
|
logger = logging.get_logger("transformers.modeling_flax_utils")
|
||||||
|
|
||||||
|
with CaptureLogger(logger) as cl:
|
||||||
|
new_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")
|
||||||
|
|
||||||
|
self.assertTrue(
|
||||||
|
"Some of the weights of FlaxBertModel were initialized in bfloat16 precision from the model checkpoint"
|
||||||
|
in cl.out
|
||||||
|
)
|
||||||
|
|
||||||
|
flat_params_1 = flatten_dict(new_model.params)
|
||||||
|
for value in flat_params_1.values():
|
||||||
|
self.assertEqual(value.dtype, "bfloat16")
|
||||||
|
|||||||
Reference in New Issue
Block a user