Fix weight loading issue (#14016)
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -15,6 +15,7 @@
|
|||||||
""" Classes to support TF Encoder-Decoder architectures """
|
""" Classes to support TF Encoder-Decoder architectures """
|
||||||
|
|
||||||
|
|
||||||
|
import tempfile
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
@@ -254,6 +255,11 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
|
|||||||
>>> # This is only for copying some specific attributes of this particular model.
|
>>> # This is only for copying some specific attributes of this particular model.
|
||||||
>>> model.config = _model.config
|
>>> model.config = _model.config
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> from transformers import TFEncoderDecoderModel
|
||||||
|
>>> model = TFEncoderDecoderModel.from_pretrained("ydshieh/bert2bert-cnn_dailymail-fp16")
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from_pt = kwargs.pop("from_pt", False)
|
from_pt = kwargs.pop("from_pt", False)
|
||||||
@@ -369,6 +375,14 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
|
|||||||
kwargs_encoder["load_weight_prefix"] = cls.load_weight_prefix
|
kwargs_encoder["load_weight_prefix"] = cls.load_weight_prefix
|
||||||
encoder = TFAutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
|
encoder = TFAutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
|
||||||
|
|
||||||
|
# This is necessary to make `from_pretrained` following `save_pretrained` work correctly
|
||||||
|
if kwargs_encoder.get("from_pt", None):
|
||||||
|
del kwargs_encoder["from_pt"]
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||||
|
encoder.save_pretrained(tmp_dirname)
|
||||||
|
del encoder
|
||||||
|
encoder = TFAutoModel.from_pretrained(tmp_dirname, *model_args, **kwargs_encoder)
|
||||||
|
|
||||||
decoder = kwargs_decoder.pop("model", None)
|
decoder = kwargs_decoder.pop("model", None)
|
||||||
if decoder is None:
|
if decoder is None:
|
||||||
if decoder_pretrained_model_name_or_path is None:
|
if decoder_pretrained_model_name_or_path is None:
|
||||||
@@ -397,6 +411,14 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
|
|||||||
kwargs_decoder["load_weight_prefix"] = cls.load_weight_prefix
|
kwargs_decoder["load_weight_prefix"] = cls.load_weight_prefix
|
||||||
decoder = TFAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
decoder = TFAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
||||||
|
|
||||||
|
# This is necessary to make `from_pretrained` following `save_pretrained` work correctly
|
||||||
|
if kwargs_decoder.get("from_pt", None):
|
||||||
|
del kwargs_decoder["from_pt"]
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||||
|
decoder.save_pretrained(tmp_dirname)
|
||||||
|
del decoder
|
||||||
|
decoder = TFAutoModelForCausalLM.from_pretrained(tmp_dirname, **kwargs_decoder)
|
||||||
|
|
||||||
# Make sure these 2 `tf.keras.Model` have fixed names so `from_pretrained` could load model weights correctly.
|
# Make sure these 2 `tf.keras.Model` have fixed names so `from_pretrained` could load model weights correctly.
|
||||||
if encoder.name != "encoder":
|
if encoder.name != "encoder":
|
||||||
raise ValueError("encoder model must be created with the name `encoder`.")
|
raise ValueError("encoder model must be created with the name `encoder`.")
|
||||||
|
|||||||
@@ -457,6 +457,14 @@ class TFBertEncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS])
|
self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS])
|
||||||
|
|
||||||
|
# Test with the TF checkpoint
|
||||||
|
model = TFEncoderDecoderModel.from_pretrained("ydshieh/bert2bert-cnn_dailymail-fp16")
|
||||||
|
|
||||||
|
output_ids = model.generate(input_ids=input_dict["input_ids"], max_length=None).numpy().tolist()
|
||||||
|
summary = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||||
|
|
||||||
|
self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS])
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFGPT2EncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
|
class TFGPT2EncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
|
||||||
@@ -785,6 +793,16 @@ class TFEncoderDecoderModelSaveLoadTests(unittest.TestCase):
|
|||||||
max_diff = np.max(np.abs(logits_pt.detach().cpu().numpy() - logits_tf.numpy()))
|
max_diff = np.max(np.abs(logits_pt.detach().cpu().numpy() - logits_tf.numpy()))
|
||||||
self.assertAlmostEqual(max_diff, 0.0, places=3)
|
self.assertAlmostEqual(max_diff, 0.0, places=3)
|
||||||
|
|
||||||
|
# Make sure `from_pretrained` following `save_pretrained` work and give the same result
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||||
|
encoder_decoder_tf.save_pretrained(tmp_dirname)
|
||||||
|
encoder_decoder_tf = TFEncoderDecoderModel.from_pretrained(tmp_dirname)
|
||||||
|
|
||||||
|
logits_tf_2 = encoder_decoder_tf(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits
|
||||||
|
|
||||||
|
max_diff = np.max(np.abs(logits_tf_2.numpy() - logits_tf.numpy()))
|
||||||
|
self.assertAlmostEqual(max_diff, 0.0, places=3)
|
||||||
|
|
||||||
# TensorFlow => PyTorch
|
# TensorFlow => PyTorch
|
||||||
with tempfile.TemporaryDirectory() as tmp_dirname:
|
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||||
encoder_decoder_tf.save_pretrained(tmp_dirname)
|
encoder_decoder_tf.save_pretrained(tmp_dirname)
|
||||||
|
|||||||
Reference in New Issue
Block a user