From 008a6a2208ad0a04f8a610a2504613ffcd8b3296 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Tue, 23 Jan 2024 10:28:23 +0100 Subject: [PATCH] Enable safetensors conversion from PyTorch to other frameworks without the torch requirement (#27599) * Initial commit * Requirements & tests * Tests * Tests * Rogue import * Rogue torch import * Cleanup * Apply suggestions from code review Co-authored-by: Nicolas Patry * bfloat16 management * Sanchit's comments * Import shield * apply suggestions from code review * correct bf16 * rebase --------- Co-authored-by: Nicolas Patry Co-authored-by: sanchit-gandhi --- setup.py | 2 +- src/transformers/dependency_versions_table.py | 2 +- .../modeling_flax_pytorch_utils.py | 57 ++++++----- tests/test_modeling_flax_utils.py | 99 ++++++++++++------- 4 files changed, 94 insertions(+), 66 deletions(-) diff --git a/setup.py b/setup.py index 91ad923ec3..f7e31559ac 100644 --- a/setup.py +++ b/setup.py @@ -158,7 +158,7 @@ _deps = [ "ruff==0.1.5", "sacrebleu>=1.4.12,<2.0.0", "sacremoses", - "safetensors>=0.3.1", + "safetensors>=0.4.1", "sagemaker>=2.31.0", "scikit-learn", "sentencepiece>=0.1.91,!=0.1.92", diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index 0ac5d26b6c..cecdaf3419 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -64,7 +64,7 @@ deps = { "ruff": "ruff==0.1.5", "sacrebleu": "sacrebleu>=1.4.12,<2.0.0", "sacremoses": "sacremoses", - "safetensors": "safetensors>=0.3.1", + "safetensors": "safetensors>=0.4.1", "sagemaker": "sagemaker>=2.31.0", "scikit-learn": "scikit-learn", "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92", diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py index 830d222928..87701c50f0 100644 --- a/src/transformers/modeling_flax_pytorch_utils.py +++ b/src/transformers/modeling_flax_pytorch_utils.py @@ -27,10 +27,13 @@ from flax.traverse_util import flatten_dict, unflatten_dict import transformers -from . import is_safetensors_available +from . import is_safetensors_available, is_torch_available from .utils import logging +if is_torch_available(): + import torch + if is_safetensors_available(): from safetensors import safe_open from safetensors.flax import load_file as safe_load_file @@ -48,17 +51,6 @@ def load_pytorch_checkpoint_in_flax_state_dict( flax_model, pytorch_checkpoint_path, is_sharded, allow_missing_keys=False ): """Load pytorch checkpoints in a flax model""" - try: - import torch # noqa: F401 - - from .pytorch_utils import is_torch_greater_or_equal_than_1_13 # noqa: F401 - except (ImportError, ModuleNotFoundError): - logger.error( - "Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see" - " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation" - " instructions." - ) - raise if not is_sharded: pt_path = os.path.abspath(pytorch_checkpoint_path) @@ -66,12 +58,24 @@ def load_pytorch_checkpoint_in_flax_state_dict( if pt_path.endswith(".safetensors"): pt_state_dict = {} - with safe_open(pt_path, framework="pt") as f: + with safe_open(pt_path, framework="flax") as f: for k in f.keys(): pt_state_dict[k] = f.get_tensor(k) else: + try: + import torch # noqa: F401 + + from .pytorch_utils import is_torch_greater_or_equal_than_1_13 # noqa: F401 + except (ImportError, ModuleNotFoundError): + logger.error( + "Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see" + " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation" + " instructions." + ) + raise + pt_state_dict = torch.load(pt_path, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13) - logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.") + logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.") flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model) else: @@ -149,21 +153,17 @@ def rename_key_and_reshape_tensor( def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): # convert pytorch tensor to numpy - # numpy currently does not support bfloat16, need to go over float32 in this case to not lose precision - try: - import torch # noqa: F401 - except (ImportError, ModuleNotFoundError): - logger.error( - "Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see" - " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation" - " instructions." - ) - raise + from_bin = is_torch_available() and isinstance(next(iter(pt_state_dict.values())), torch.Tensor) + bfloat16 = torch.bfloat16 if from_bin else "bfloat16" weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()} - pt_state_dict = { - k: v.numpy() if not v.dtype == torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items() - } + + if from_bin: + for k, v in pt_state_dict.items(): + # numpy currently does not support bfloat16, need to go over float32 in this case to not lose precision + if v.dtype == bfloat16: + v = v.float() + pt_state_dict[k] = v.numpy() model_prefix = flax_model.base_model_prefix @@ -191,7 +191,7 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): # Need to change some parameters name to match Flax names for pt_key, pt_tensor in pt_state_dict.items(): pt_tuple_key = tuple(pt_key.split(".")) - is_bfloat_16 = weight_dtypes[pt_key] == torch.bfloat16 + is_bfloat_16 = weight_dtypes[pt_key] == bfloat16 # remove base model prefix if necessary has_base_model_prefix = pt_tuple_key[0] == model_prefix @@ -229,7 +229,6 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): flax_state_dict[("params",) + flax_key] = ( jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16) ) - else: # also add unexpected weight so that warning is thrown flax_state_dict[flax_key] = ( diff --git a/tests/test_modeling_flax_utils.py b/tests/test_modeling_flax_utils.py index e668b43532..0309a3bd8f 100644 --- a/tests/test_modeling_flax_utils.py +++ b/tests/test_modeling_flax_utils.py @@ -23,13 +23,14 @@ from transformers import BertConfig, BertModel, is_flax_available, is_torch_avai from transformers.testing_utils import ( TOKEN, USER, + CaptureLogger, is_pt_flax_cross_test, is_staging_test, require_flax, require_safetensors, require_torch, ) -from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME +from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME, logging if is_flax_available(): @@ -42,6 +43,9 @@ if is_flax_available(): os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8 +if is_torch_available(): + import torch + @require_flax @is_staging_test @@ -251,7 +255,6 @@ class FlaxModelUtilsTest(unittest.TestCase): self.assertTrue(check_models_equal(flax_model, safetensors_model)) - @require_torch @require_safetensors @is_pt_flax_cross_test def test_safetensors_load_from_hub_from_safetensors_pt(self): @@ -265,7 +268,27 @@ class FlaxModelUtilsTest(unittest.TestCase): safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors") self.assertTrue(check_models_equal(flax_model, safetensors_model)) + @require_safetensors @require_torch + @is_pt_flax_cross_test + def test_safetensors_load_from_hub_from_safetensors_pt_bf16(self): + """ + This test checks that we can load safetensors from a checkpoint that only has those on the Hub. + saved in the "pt" format. + """ + import torch + + model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors") + model.to(torch.bfloat16) + + with tempfile.TemporaryDirectory() as tmp: + model.save_pretrained(tmp) + flax_model = FlaxBertModel.from_pretrained(tmp) + + # Can load from the PyTorch-formatted checkpoint + safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16") + self.assertTrue(check_models_equal(flax_model, safetensors_model)) + @require_safetensors @is_pt_flax_cross_test def test_safetensors_load_from_local_from_safetensors_pt(self): @@ -284,39 +307,6 @@ class FlaxModelUtilsTest(unittest.TestCase): self.assertTrue(check_models_equal(flax_model, safetensors_model)) - @require_safetensors - def test_safetensors_load_from_hub_from_safetensors_pt_without_torch_installed(self): - """ - This test checks that we cannot load safetensors from a checkpoint that only has safetensors - saved in the "pt" format if torch isn't installed. - """ - if is_torch_available(): - # This test verifies that a correct error message is shown when loading from a pt safetensors - # PyTorch shouldn't be installed for this to work correctly. - return - - # Cannot load from the PyTorch-formatted checkpoint without PyTorch installed - with self.assertRaises(ModuleNotFoundError): - _ = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors") - - @require_safetensors - def test_safetensors_load_from_local_from_safetensors_pt_without_torch_installed(self): - """ - This test checks that we cannot load safetensors from a checkpoint that only has safetensors - saved in the "pt" format if torch isn't installed. - """ - if is_torch_available(): - # This test verifies that a correct error message is shown when loading from a pt safetensors - # PyTorch shouldn't be installed for this to work correctly. - return - - with tempfile.TemporaryDirectory() as tmp: - location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp) - - # Cannot load from the PyTorch-formatted checkpoint without PyTorch installed - with self.assertRaises(ModuleNotFoundError): - _ = FlaxBertModel.from_pretrained(location) - @require_safetensors def test_safetensors_load_from_hub_msgpack_before_safetensors(self): """ @@ -347,6 +337,7 @@ class FlaxModelUtilsTest(unittest.TestCase): @require_safetensors @require_torch + @is_pt_flax_cross_test def test_safetensors_flax_from_torch(self): hub_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only") model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only") @@ -372,3 +363,41 @@ class FlaxModelUtilsTest(unittest.TestCase): # This should not raise even if there are two types of sharded weights # This should discard the safetensors weights in favor of the msgpack sharded weights FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-safetensors-msgpack-sharded") + + @require_safetensors + def test_safetensors_from_pt_bf16(self): + # This should not raise; should be able to load bf16-serialized torch safetensors without issue + # and without torch. + logger = logging.get_logger("transformers.modeling_flax_utils") + + with CaptureLogger(logger) as cl: + FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16") + + self.assertTrue( + "Some of the weights of FlaxBertModel were initialized in bfloat16 precision from the model checkpoint" + in cl.out + ) + + @require_torch + @require_safetensors + @is_pt_flax_cross_test + def test_from_pt_bf16(self): + model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only") + model.to(torch.bfloat16) + + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir, safe_serialization=False) + + logger = logging.get_logger("transformers.modeling_flax_utils") + + with CaptureLogger(logger) as cl: + new_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16") + + self.assertTrue( + "Some of the weights of FlaxBertModel were initialized in bfloat16 precision from the model checkpoint" + in cl.out + ) + + flat_params_1 = flatten_dict(new_model.params) + for value in flat_params_1.values(): + self.assertEqual(value.dtype, "bfloat16")