Fix from_pt flag when loading with safetensors (#27394)
* Fix * Tests * Fix
This commit is contained in:
@@ -166,6 +166,7 @@ def load_pytorch_checkpoint_in_tf2_model(
|
|||||||
try:
|
try:
|
||||||
import tensorflow as tf # noqa: F401
|
import tensorflow as tf # noqa: F401
|
||||||
import torch # noqa: F401
|
import torch # noqa: F401
|
||||||
|
from safetensors.torch import load_file as safe_load_file # noqa: F401
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
|
"Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
|
||||||
@@ -182,7 +183,12 @@ def load_pytorch_checkpoint_in_tf2_model(
|
|||||||
for path in pytorch_checkpoint_path:
|
for path in pytorch_checkpoint_path:
|
||||||
pt_path = os.path.abspath(path)
|
pt_path = os.path.abspath(path)
|
||||||
logger.info(f"Loading PyTorch weights from {pt_path}")
|
logger.info(f"Loading PyTorch weights from {pt_path}")
|
||||||
pt_state_dict.update(torch.load(pt_path, map_location="cpu"))
|
if pt_path.endswith(".safetensors"):
|
||||||
|
state_dict = safe_load_file(pt_path)
|
||||||
|
else:
|
||||||
|
state_dict = torch.load(pt_path, map_location="cpu")
|
||||||
|
|
||||||
|
pt_state_dict.update(state_dict)
|
||||||
|
|
||||||
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters")
|
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters")
|
||||||
|
|
||||||
|
|||||||
@@ -246,6 +246,10 @@ class MPNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_mpnet_for_question_answering(*config_and_inputs)
|
self.model_tester.create_and_check_mpnet_for_question_answering(*config_and_inputs)
|
||||||
|
|
||||||
|
@unittest.skip("This isn't passing but should, seems like a misconfiguration of tied weights.")
|
||||||
|
def test_tf_from_pt_safetensors(self):
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class MPNetModelIntegrationTest(unittest.TestCase):
|
class MPNetModelIntegrationTest(unittest.TestCase):
|
||||||
|
|||||||
@@ -824,6 +824,12 @@ class Wav2Vec2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
|||||||
# (Even with this call, there are still memory leak by ~0.04MB)
|
# (Even with this call, there are still memory leak by ~0.04MB)
|
||||||
self.clear_torch_jit_class_registry()
|
self.clear_torch_jit_class_registry()
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
"Need to investigate why config.do_stable_layer_norm is set to False here when it doesn't seem to be supported"
|
||||||
|
)
|
||||||
|
def test_flax_from_pt_safetensors(self):
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
|||||||
@@ -105,6 +105,7 @@ if is_tf_available():
|
|||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
from tests.test_modeling_flax_utils import check_models_equal
|
||||||
from transformers.modeling_flax_pytorch_utils import (
|
from transformers.modeling_flax_pytorch_utils import (
|
||||||
convert_pytorch_state_dict_to_flax,
|
convert_pytorch_state_dict_to_flax,
|
||||||
load_flax_weights_in_pytorch_model,
|
load_flax_weights_in_pytorch_model,
|
||||||
@@ -3219,6 +3220,55 @@ class ModelTesterMixin:
|
|||||||
# with attention mask
|
# with attention mask
|
||||||
_ = model(dummy_input, attention_mask=dummy_attention_mask)
|
_ = model(dummy_input, attention_mask=dummy_attention_mask)
|
||||||
|
|
||||||
|
@is_pt_tf_cross_test
|
||||||
|
def test_tf_from_pt_safetensors(self):
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
tf_model_class_name = "TF" + model_class.__name__ # Add the "TF" at the beginning
|
||||||
|
if not hasattr(transformers, tf_model_class_name):
|
||||||
|
# transformers does not have this model in TF version yet
|
||||||
|
return
|
||||||
|
|
||||||
|
tf_model_class = getattr(transformers, tf_model_class_name)
|
||||||
|
|
||||||
|
pt_model = model_class(config)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
pt_model.save_pretrained(tmpdirname, safe_serialization=True)
|
||||||
|
tf_model_1 = tf_model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||||
|
|
||||||
|
pt_model.save_pretrained(tmpdirname, safe_serialization=False)
|
||||||
|
tf_model_2 = tf_model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||||
|
|
||||||
|
# Check models are equal
|
||||||
|
for p1, p2 in zip(tf_model_1.weights, tf_model_2.weights):
|
||||||
|
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||||
|
|
||||||
|
@is_pt_flax_cross_test
|
||||||
|
def test_flax_from_pt_safetensors(self):
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
flax_model_class_name = "Flax" + model_class.__name__ # Add the "Flax at the beginning
|
||||||
|
if not hasattr(transformers, flax_model_class_name):
|
||||||
|
# transformers does not have this model in Flax version yet
|
||||||
|
return
|
||||||
|
|
||||||
|
flax_model_class = getattr(transformers, flax_model_class_name)
|
||||||
|
|
||||||
|
pt_model = model_class(config)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
pt_model.save_pretrained(tmpdirname, safe_serialization=True)
|
||||||
|
flax_model_1 = flax_model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||||
|
|
||||||
|
pt_model.save_pretrained(tmpdirname, safe_serialization=False)
|
||||||
|
flax_model_2 = flax_model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||||
|
|
||||||
|
# Check models are equal
|
||||||
|
self.assertTrue(check_models_equal(flax_model_1, flax_model_2))
|
||||||
|
|
||||||
|
|
||||||
global_rng = random.Random()
|
global_rng = random.Random()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user