Fix from_pt flag when loading with safetensors (#27394)
* Fix * Tests * Fix
This commit is contained in:
@@ -105,6 +105,7 @@ if is_tf_available():
|
||||
if is_flax_available():
|
||||
import jax.numpy as jnp
|
||||
|
||||
from tests.test_modeling_flax_utils import check_models_equal
|
||||
from transformers.modeling_flax_pytorch_utils import (
|
||||
convert_pytorch_state_dict_to_flax,
|
||||
load_flax_weights_in_pytorch_model,
|
||||
@@ -3219,6 +3220,55 @@ class ModelTesterMixin:
|
||||
# with attention mask
|
||||
_ = model(dummy_input, attention_mask=dummy_attention_mask)
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_tf_from_pt_safetensors(self):
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
tf_model_class_name = "TF" + model_class.__name__ # Add the "TF" at the beginning
|
||||
if not hasattr(transformers, tf_model_class_name):
|
||||
# transformers does not have this model in TF version yet
|
||||
return
|
||||
|
||||
tf_model_class = getattr(transformers, tf_model_class_name)
|
||||
|
||||
pt_model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname, safe_serialization=True)
|
||||
tf_model_1 = tf_model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
pt_model.save_pretrained(tmpdirname, safe_serialization=False)
|
||||
tf_model_2 = tf_model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
# Check models are equal
|
||||
for p1, p2 in zip(tf_model_1.weights, tf_model_2.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_flax_from_pt_safetensors(self):
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
flax_model_class_name = "Flax" + model_class.__name__ # Add the "Flax at the beginning
|
||||
if not hasattr(transformers, flax_model_class_name):
|
||||
# transformers does not have this model in Flax version yet
|
||||
return
|
||||
|
||||
flax_model_class = getattr(transformers, flax_model_class_name)
|
||||
|
||||
pt_model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname, safe_serialization=True)
|
||||
flax_model_1 = flax_model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
pt_model.save_pretrained(tmpdirname, safe_serialization=False)
|
||||
flax_model_2 = flax_model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
# Check models are equal
|
||||
self.assertTrue(check_models_equal(flax_model_1, flax_model_2))
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user