Safetensors serialization by default (#27064)

* Safetensors serialization by default

* First pass on the tests

* Second pass on the tests

* Third pass on the tests

* Fix TF weight loading from TF-format safetensors

* Specific encoder-decoder fixes for weight crossloading

* Add VisionEncoderDecoder fixes for TF too

* Change filename test for pt-to-tf

* One missing fix for TFVisionEncoderDecoder

* Fix the other crossload test

* Support for flax + updated tests

* Apply suggestions from code review

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Sanchit's comments

* Sanchit's comments 2

* Nico's comments

* Fix tests

* cleanup

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: Matt <rocketknight1@gmail.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Lysandre Debut
2023-10-31 19:16:49 +01:00
committed by GitHub
parent 25e6e9418c
commit 113ebf80ac
20 changed files with 433 additions and 137 deletions

View File

@@ -211,6 +211,8 @@ class TFAutoModelTest(unittest.TestCase):
config = copy.deepcopy(model.config)
config.architectures = ["FunnelBaseModel"]
model = TFAutoModel.from_config(config)
model.build()
self.assertIsInstance(model, TFFunnelBaseModel)
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -245,7 +247,10 @@ class TFAutoModelTest(unittest.TestCase):
# Now that the config is registered, it can be used as any other config with the auto-API
tiny_config = BertModelTester(self).get_config()
config = NewModelConfig(**tiny_config.to_dict())
model = auto_class.from_config(config)
model.build()
self.assertIsInstance(model, TFNewModel)
with tempfile.TemporaryDirectory() as tmp_dir:

View File

@@ -525,7 +525,7 @@ class TFEncoderDecoderMixin:
# PT -> TF
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
tf_model = TFEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
tf_model = TFEncoderDecoderModel.from_pretrained(tmpdirname)
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
@@ -542,7 +542,7 @@ class TFEncoderDecoderMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
tf_model = TFEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
tf_model = TFEncoderDecoderModel.from_pretrained(tmpdirname)
self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)
@@ -560,7 +560,8 @@ class TFEncoderDecoderMixin:
tf_model(**tf_inputs_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
tf_model.save_pretrained(tmpdirname)
# TODO Matt: PT doesn't support loading TF-format safetensors yet - remove `safe_serialization=False` and `from_tf=True` when it does
tf_model.save_pretrained(tmpdirname, safe_serialization=False)
pt_model = EncoderDecoderModel.from_pretrained(tmpdirname, from_tf=True)
self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)
@@ -1129,9 +1130,7 @@ class TFEncoderDecoderModelSaveLoadTests(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmp_dirname_1, tempfile.TemporaryDirectory() as tmp_dirname_2:
encoder_decoder_pt.encoder.save_pretrained(tmp_dirname_1)
encoder_decoder_pt.decoder.save_pretrained(tmp_dirname_2)
encoder_decoder_tf = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
tmp_dirname_1, tmp_dirname_2, encoder_from_pt=True, decoder_from_pt=True
)
encoder_decoder_tf = TFEncoderDecoderModel.from_encoder_decoder_pretrained(tmp_dirname_1, tmp_dirname_2)
logits_tf = encoder_decoder_tf(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits
@@ -1150,7 +1149,7 @@ class TFEncoderDecoderModelSaveLoadTests(unittest.TestCase):
# TensorFlow => PyTorch
with tempfile.TemporaryDirectory() as tmp_dirname:
encoder_decoder_tf.save_pretrained(tmp_dirname)
encoder_decoder_tf.save_pretrained(tmp_dirname, safe_serialization=False)
encoder_decoder_pt = EncoderDecoderModel.from_pretrained(tmp_dirname, from_tf=True)
max_diff = np.max(np.abs(logits_pt.detach().cpu().numpy() - logits_tf.numpy()))

View File

@@ -458,7 +458,7 @@ class TFVisionEncoderDecoderMixin:
# PT -> TF
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
tf_model = TFVisionEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
tf_model = TFVisionEncoderDecoderModel.from_pretrained(tmpdirname)
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
@@ -473,7 +473,7 @@ class TFVisionEncoderDecoderMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
tf_model = TFVisionEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
tf_model = TFVisionEncoderDecoderModel.from_pretrained(tmpdirname)
self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)
@@ -489,7 +489,7 @@ class TFVisionEncoderDecoderMixin:
tf_model(**tf_inputs_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
tf_model.save_pretrained(tmpdirname)
tf_model.save_pretrained(tmpdirname, safe_serialization=False)
pt_model = VisionEncoderDecoderModel.from_pretrained(tmpdirname, from_tf=True)
self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)
@@ -803,7 +803,7 @@ class TFVisionEncoderDecoderModelSaveLoadTests(unittest.TestCase):
encoder_decoder_pt.encoder.save_pretrained(tmp_dirname_1)
encoder_decoder_pt.decoder.save_pretrained(tmp_dirname_2)
encoder_decoder_tf = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
tmp_dirname_1, tmp_dirname_2, encoder_from_pt=True, decoder_from_pt=True
tmp_dirname_1, tmp_dirname_2
)
logits_tf = encoder_decoder_tf(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids).logits
@@ -814,7 +814,7 @@ class TFVisionEncoderDecoderModelSaveLoadTests(unittest.TestCase):
# Make sure `from_pretrained` following `save_pretrained` work and give the same result
# (See https://github.com/huggingface/transformers/pull/14016)
with tempfile.TemporaryDirectory() as tmp_dirname:
encoder_decoder_tf.save_pretrained(tmp_dirname)
encoder_decoder_tf.save_pretrained(tmp_dirname, safe_serialization=False)
encoder_decoder_tf = TFVisionEncoderDecoderModel.from_pretrained(tmp_dirname)
logits_tf_2 = encoder_decoder_tf(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids).logits