Enable safetensors conversion from PyTorch to other frameworks without the torch requirement (#27599)

* Initial commit

* Requirements & tests

* Tests

* Tests

* Rogue import

* Rogue torch import

* Cleanup

* Apply suggestions from code review

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>

* bfloat16 management

* Sanchit's comments

* Import shield

* apply suggestions from code review

* correct bf16

* rebase

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
This commit is contained in:
Lysandre Debut
2024-01-23 10:28:23 +01:00
committed by GitHub
parent 039866094c
commit 008a6a2208
4 changed files with 94 additions and 66 deletions

View File

@@ -158,7 +158,7 @@ _deps = [
"ruff==0.1.5", "ruff==0.1.5",
"sacrebleu>=1.4.12,<2.0.0", "sacrebleu>=1.4.12,<2.0.0",
"sacremoses", "sacremoses",
"safetensors>=0.3.1", "safetensors>=0.4.1",
"sagemaker>=2.31.0", "sagemaker>=2.31.0",
"scikit-learn", "scikit-learn",
"sentencepiece>=0.1.91,!=0.1.92", "sentencepiece>=0.1.91,!=0.1.92",

View File

@@ -64,7 +64,7 @@ deps = {
"ruff": "ruff==0.1.5", "ruff": "ruff==0.1.5",
"sacrebleu": "sacrebleu>=1.4.12,<2.0.0", "sacrebleu": "sacrebleu>=1.4.12,<2.0.0",
"sacremoses": "sacremoses", "sacremoses": "sacremoses",
"safetensors": "safetensors>=0.3.1", "safetensors": "safetensors>=0.4.1",
"sagemaker": "sagemaker>=2.31.0", "sagemaker": "sagemaker>=2.31.0",
"scikit-learn": "scikit-learn", "scikit-learn": "scikit-learn",
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92", "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",

View File

@@ -27,10 +27,13 @@ from flax.traverse_util import flatten_dict, unflatten_dict
import transformers import transformers
from . import is_safetensors_available from . import is_safetensors_available, is_torch_available
from .utils import logging from .utils import logging
if is_torch_available():
import torch
if is_safetensors_available(): if is_safetensors_available():
from safetensors import safe_open from safetensors import safe_open
from safetensors.flax import load_file as safe_load_file from safetensors.flax import load_file as safe_load_file
@@ -48,17 +51,6 @@ def load_pytorch_checkpoint_in_flax_state_dict(
flax_model, pytorch_checkpoint_path, is_sharded, allow_missing_keys=False flax_model, pytorch_checkpoint_path, is_sharded, allow_missing_keys=False
): ):
"""Load pytorch checkpoints in a flax model""" """Load pytorch checkpoints in a flax model"""
try:
import torch # noqa: F401
from .pytorch_utils import is_torch_greater_or_equal_than_1_13 # noqa: F401
except (ImportError, ModuleNotFoundError):
logger.error(
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
" instructions."
)
raise
if not is_sharded: if not is_sharded:
pt_path = os.path.abspath(pytorch_checkpoint_path) pt_path = os.path.abspath(pytorch_checkpoint_path)
@@ -66,12 +58,24 @@ def load_pytorch_checkpoint_in_flax_state_dict(
if pt_path.endswith(".safetensors"): if pt_path.endswith(".safetensors"):
pt_state_dict = {} pt_state_dict = {}
with safe_open(pt_path, framework="pt") as f: with safe_open(pt_path, framework="flax") as f:
for k in f.keys(): for k in f.keys():
pt_state_dict[k] = f.get_tensor(k) pt_state_dict[k] = f.get_tensor(k)
else: else:
try:
import torch # noqa: F401
from .pytorch_utils import is_torch_greater_or_equal_than_1_13 # noqa: F401
except (ImportError, ModuleNotFoundError):
logger.error(
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
" instructions."
)
raise
pt_state_dict = torch.load(pt_path, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13) pt_state_dict = torch.load(pt_path, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13)
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.") logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model) flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
else: else:
@@ -149,21 +153,17 @@ def rename_key_and_reshape_tensor(
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
# convert pytorch tensor to numpy # convert pytorch tensor to numpy
# numpy currently does not support bfloat16, need to go over float32 in this case to not lose precision from_bin = is_torch_available() and isinstance(next(iter(pt_state_dict.values())), torch.Tensor)
try: bfloat16 = torch.bfloat16 if from_bin else "bfloat16"
import torch # noqa: F401
except (ImportError, ModuleNotFoundError):
logger.error(
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
" instructions."
)
raise
weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()} weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}
pt_state_dict = {
k: v.numpy() if not v.dtype == torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items() if from_bin:
} for k, v in pt_state_dict.items():
# numpy currently does not support bfloat16, need to go over float32 in this case to not lose precision
if v.dtype == bfloat16:
v = v.float()
pt_state_dict[k] = v.numpy()
model_prefix = flax_model.base_model_prefix model_prefix = flax_model.base_model_prefix
@@ -191,7 +191,7 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
# Need to change some parameters name to match Flax names # Need to change some parameters name to match Flax names
for pt_key, pt_tensor in pt_state_dict.items(): for pt_key, pt_tensor in pt_state_dict.items():
pt_tuple_key = tuple(pt_key.split(".")) pt_tuple_key = tuple(pt_key.split("."))
is_bfloat_16 = weight_dtypes[pt_key] == torch.bfloat16 is_bfloat_16 = weight_dtypes[pt_key] == bfloat16
# remove base model prefix if necessary # remove base model prefix if necessary
has_base_model_prefix = pt_tuple_key[0] == model_prefix has_base_model_prefix = pt_tuple_key[0] == model_prefix
@@ -229,7 +229,6 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
flax_state_dict[("params",) + flax_key] = ( flax_state_dict[("params",) + flax_key] = (
jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16) jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
) )
else: else:
# also add unexpected weight so that warning is thrown # also add unexpected weight so that warning is thrown
flax_state_dict[flax_key] = ( flax_state_dict[flax_key] = (

View File

@@ -23,13 +23,14 @@ from transformers import BertConfig, BertModel, is_flax_available, is_torch_avai
from transformers.testing_utils import ( from transformers.testing_utils import (
TOKEN, TOKEN,
USER, USER,
CaptureLogger,
is_pt_flax_cross_test, is_pt_flax_cross_test,
is_staging_test, is_staging_test,
require_flax, require_flax,
require_safetensors, require_safetensors,
require_torch, require_torch,
) )
from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME, logging
if is_flax_available(): if is_flax_available():
@@ -42,6 +43,9 @@ if is_flax_available():
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8 os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
if is_torch_available():
import torch
@require_flax @require_flax
@is_staging_test @is_staging_test
@@ -251,7 +255,6 @@ class FlaxModelUtilsTest(unittest.TestCase):
self.assertTrue(check_models_equal(flax_model, safetensors_model)) self.assertTrue(check_models_equal(flax_model, safetensors_model))
@require_torch
@require_safetensors @require_safetensors
@is_pt_flax_cross_test @is_pt_flax_cross_test
def test_safetensors_load_from_hub_from_safetensors_pt(self): def test_safetensors_load_from_hub_from_safetensors_pt(self):
@@ -265,7 +268,27 @@ class FlaxModelUtilsTest(unittest.TestCase):
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors") safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
self.assertTrue(check_models_equal(flax_model, safetensors_model)) self.assertTrue(check_models_equal(flax_model, safetensors_model))
@require_safetensors
@require_torch @require_torch
@is_pt_flax_cross_test
def test_safetensors_load_from_hub_from_safetensors_pt_bf16(self):
"""
This test checks that we can load safetensors from a checkpoint that only has those on the Hub.
saved in the "pt" format.
"""
import torch
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
model.to(torch.bfloat16)
with tempfile.TemporaryDirectory() as tmp:
model.save_pretrained(tmp)
flax_model = FlaxBertModel.from_pretrained(tmp)
# Can load from the PyTorch-formatted checkpoint
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")
self.assertTrue(check_models_equal(flax_model, safetensors_model))
@require_safetensors @require_safetensors
@is_pt_flax_cross_test @is_pt_flax_cross_test
def test_safetensors_load_from_local_from_safetensors_pt(self): def test_safetensors_load_from_local_from_safetensors_pt(self):
@@ -284,39 +307,6 @@ class FlaxModelUtilsTest(unittest.TestCase):
self.assertTrue(check_models_equal(flax_model, safetensors_model)) self.assertTrue(check_models_equal(flax_model, safetensors_model))
@require_safetensors
def test_safetensors_load_from_hub_from_safetensors_pt_without_torch_installed(self):
"""
This test checks that we cannot load safetensors from a checkpoint that only has safetensors
saved in the "pt" format if torch isn't installed.
"""
if is_torch_available():
# This test verifies that a correct error message is shown when loading from a pt safetensors
# PyTorch shouldn't be installed for this to work correctly.
return
# Cannot load from the PyTorch-formatted checkpoint without PyTorch installed
with self.assertRaises(ModuleNotFoundError):
_ = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
@require_safetensors
def test_safetensors_load_from_local_from_safetensors_pt_without_torch_installed(self):
"""
This test checks that we cannot load safetensors from a checkpoint that only has safetensors
saved in the "pt" format if torch isn't installed.
"""
if is_torch_available():
# This test verifies that a correct error message is shown when loading from a pt safetensors
# PyTorch shouldn't be installed for this to work correctly.
return
with tempfile.TemporaryDirectory() as tmp:
location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp)
# Cannot load from the PyTorch-formatted checkpoint without PyTorch installed
with self.assertRaises(ModuleNotFoundError):
_ = FlaxBertModel.from_pretrained(location)
@require_safetensors @require_safetensors
def test_safetensors_load_from_hub_msgpack_before_safetensors(self): def test_safetensors_load_from_hub_msgpack_before_safetensors(self):
""" """
@@ -347,6 +337,7 @@ class FlaxModelUtilsTest(unittest.TestCase):
@require_safetensors @require_safetensors
@require_torch @require_torch
@is_pt_flax_cross_test
def test_safetensors_flax_from_torch(self): def test_safetensors_flax_from_torch(self):
hub_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only") hub_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only") model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
@@ -372,3 +363,41 @@ class FlaxModelUtilsTest(unittest.TestCase):
# This should not raise even if there are two types of sharded weights # This should not raise even if there are two types of sharded weights
# This should discard the safetensors weights in favor of the msgpack sharded weights # This should discard the safetensors weights in favor of the msgpack sharded weights
FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-safetensors-msgpack-sharded") FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-safetensors-msgpack-sharded")
@require_safetensors
def test_safetensors_from_pt_bf16(self):
# This should not raise; should be able to load bf16-serialized torch safetensors without issue
# and without torch.
logger = logging.get_logger("transformers.modeling_flax_utils")
with CaptureLogger(logger) as cl:
FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")
self.assertTrue(
"Some of the weights of FlaxBertModel were initialized in bfloat16 precision from the model checkpoint"
in cl.out
)
@require_torch
@require_safetensors
@is_pt_flax_cross_test
def test_from_pt_bf16(self):
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
model.to(torch.bfloat16)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, safe_serialization=False)
logger = logging.get_logger("transformers.modeling_flax_utils")
with CaptureLogger(logger) as cl:
new_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")
self.assertTrue(
"Some of the weights of FlaxBertModel were initialized in bfloat16 precision from the model checkpoint"
in cl.out
)
flat_params_1 = flatten_dict(new_model.params)
for value in flat_params_1.values():
self.assertEqual(value.dtype, "bfloat16")