Safetensors serialization by default (#27064)
* Safetensors serialization by default * First pass on the tests * Second pass on the tests * Third pass on the tests * Fix TF weight loading from TF-format safetensors * Specific encoder-decoder fixes for weight crossloading * Add VisionEncoderDecoder fixes for TF too * Change filename test for pt-to-tf * One missing fix for TFVisionEncoderDecoder * Fix the other crossload test * Support for flax + updated tests * Apply suggestions from code review Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Sanchit's comments * Sanchit's comments 2 * Nico's comments * Fix tests * cleanup * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: Matt <rocketknight1@gmail.com> Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -525,7 +525,7 @@ class TFEncoderDecoderMixin:
|
||||
# PT -> TF
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
tf_model = TFEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
|
||||
tf_model = TFEncoderDecoderModel.from_pretrained(tmpdirname)
|
||||
|
||||
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
|
||||
|
||||
@@ -542,7 +542,7 @@ class TFEncoderDecoderMixin:
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
tf_model = TFEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
|
||||
tf_model = TFEncoderDecoderModel.from_pretrained(tmpdirname)
|
||||
|
||||
self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)
|
||||
|
||||
@@ -560,7 +560,8 @@ class TFEncoderDecoderMixin:
|
||||
tf_model(**tf_inputs_dict)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
tf_model.save_pretrained(tmpdirname)
|
||||
# TODO Matt: PT doesn't support loading TF safetensors - remove the arg and from_tf=True when it does
|
||||
tf_model.save_pretrained(tmpdirname, safe_serialization=False)
|
||||
pt_model = EncoderDecoderModel.from_pretrained(tmpdirname, from_tf=True)
|
||||
|
||||
self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)
|
||||
@@ -1129,9 +1130,7 @@ class TFEncoderDecoderModelSaveLoadTests(unittest.TestCase):
|
||||
with tempfile.TemporaryDirectory() as tmp_dirname_1, tempfile.TemporaryDirectory() as tmp_dirname_2:
|
||||
encoder_decoder_pt.encoder.save_pretrained(tmp_dirname_1)
|
||||
encoder_decoder_pt.decoder.save_pretrained(tmp_dirname_2)
|
||||
encoder_decoder_tf = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||
tmp_dirname_1, tmp_dirname_2, encoder_from_pt=True, decoder_from_pt=True
|
||||
)
|
||||
encoder_decoder_tf = TFEncoderDecoderModel.from_encoder_decoder_pretrained(tmp_dirname_1, tmp_dirname_2)
|
||||
|
||||
logits_tf = encoder_decoder_tf(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits
|
||||
|
||||
@@ -1150,7 +1149,7 @@ class TFEncoderDecoderModelSaveLoadTests(unittest.TestCase):
|
||||
|
||||
# TensorFlow => PyTorch
|
||||
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||
encoder_decoder_tf.save_pretrained(tmp_dirname)
|
||||
encoder_decoder_tf.save_pretrained(tmp_dirname, safe_serialization=False)
|
||||
encoder_decoder_pt = EncoderDecoderModel.from_pretrained(tmp_dirname, from_tf=True)
|
||||
|
||||
max_diff = np.max(np.abs(logits_pt.detach().cpu().numpy() - logits_tf.numpy()))
|
||||
|
||||
Reference in New Issue
Block a user