PT <-> TF for composite models (#19732)
* First step of PT->TF for composite models * Update the tests * For VisionEncoderDecoderModel * Fix * Fix * Add comment * Fix * clean up import * Save memory * For (TF)EncoderDecoderModel * For (TF)EncoderDecoderModel Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -523,15 +523,9 @@ class TFEncoderDecoderMixin:
|
||||
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
|
||||
|
||||
# PT -> TF
|
||||
with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
|
||||
|
||||
pt_model.encoder.save_pretrained(encoder_tmp_dirname)
|
||||
pt_model.decoder.save_pretrained(decoder_tmp_dirname)
|
||||
tf_model_loaded = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||
encoder_tmp_dirname, decoder_tmp_dirname, encoder_from_pt=True, decoder_from_pt=True
|
||||
)
|
||||
# This is only for copying some specific attributes of this particular model.
|
||||
tf_model_loaded.config = pt_model.config
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
tf_model = TFEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
|
||||
|
||||
@@ -546,15 +540,9 @@ class TFEncoderDecoderMixin:
|
||||
|
||||
pt_model = EncoderDecoderModel(encoder_decoder_config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
|
||||
|
||||
pt_model.encoder.save_pretrained(encoder_tmp_dirname)
|
||||
pt_model.decoder.save_pretrained(decoder_tmp_dirname)
|
||||
tf_model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||
encoder_tmp_dirname, decoder_tmp_dirname, encoder_from_pt=True, decoder_from_pt=True
|
||||
)
|
||||
# This is only for copying some specific attributes of this particular model.
|
||||
tf_model.config = pt_model.config
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
tf_model = TFEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)
|
||||
|
||||
@@ -567,33 +555,13 @@ class TFEncoderDecoderMixin:
|
||||
# TODO: A generalizable way to determine this attribute
|
||||
encoder_decoder_config.output_attentions = True
|
||||
|
||||
# Using `_tf_model`, the test will fail, because the weights of `_tf_model` get extended before saving
|
||||
# the encoder/decoder models.
|
||||
# There was a (very) ugly potential fix, which wasn't integrated to `transformers`: see
|
||||
# https://github.com/huggingface/transformers/pull/13222/commits/dbb3c9de76eee235791d2064094654637c99f36d#r697304245
|
||||
# (the change in `src/transformers/modeling_tf_utils.py`)
|
||||
_tf_model = TFEncoderDecoderModel(encoder_decoder_config)
|
||||
# Make sure model is built
|
||||
_tf_model(**tf_inputs_dict)
|
||||
tf_model = TFEncoderDecoderModel(encoder_decoder_config)
|
||||
# Make sure model is built before saving
|
||||
tf_model(**tf_inputs_dict)
|
||||
|
||||
# Using `tf_model` to pass the test.
|
||||
encoder = _tf_model.encoder.__class__(encoder_decoder_config.encoder)
|
||||
decoder = _tf_model.decoder.__class__(encoder_decoder_config.decoder)
|
||||
# Make sure models are built
|
||||
encoder(encoder.dummy_inputs)
|
||||
decoder(decoder.dummy_inputs)
|
||||
tf_model = TFEncoderDecoderModel(encoder=encoder, decoder=decoder)
|
||||
tf_model.config = encoder_decoder_config
|
||||
|
||||
with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
|
||||
|
||||
tf_model.encoder.save_pretrained(encoder_tmp_dirname)
|
||||
tf_model.decoder.save_pretrained(decoder_tmp_dirname)
|
||||
pt_model = EncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||
encoder_tmp_dirname, decoder_tmp_dirname, encoder_from_tf=True, decoder_from_tf=True
|
||||
)
|
||||
# This is only for copying some specific attributes of this particular model.
|
||||
pt_model.config = tf_model.config
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
tf_model.save_pretrained(tmpdirname)
|
||||
pt_model = EncoderDecoderModel.from_pretrained(tmpdirname, from_tf=True)
|
||||
|
||||
self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)
|
||||
|
||||
@@ -696,20 +664,11 @@ class TFEncoderDecoderMixin:
|
||||
self.check_pt_to_tf_equivalence(config, decoder_config, tf_inputs_dict_with_labels)
|
||||
self.check_tf_to_pt_equivalence(config, decoder_config, tf_inputs_dict_with_labels)
|
||||
|
||||
# This is not working, because pt/tf equivalence test for encoder-decoder use `from_encoder_decoder_pretrained`,
|
||||
# which randomly initialize `enc_to_dec_proj`.
|
||||
# check `enc_to_dec_proj` work as expected
|
||||
# decoder_config.hidden_size = decoder_config.hidden_size * 2
|
||||
# self.assertTrue(config.hidden_size != decoder_config.hidden_size)
|
||||
# self.check_pt_to_tf_equivalence(config, decoder_config, tf_inputs_dict)
|
||||
# self.check_tf_to_pt_equivalence(config, decoder_config, tf_inputs_dict)
|
||||
|
||||
# Let's just check `enc_to_dec_proj` can run for now
|
||||
decoder_config.hidden_size = decoder_config.hidden_size * 2
|
||||
self.assertTrue(config.hidden_size != decoder_config.hidden_size)
|
||||
encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
|
||||
model = TFEncoderDecoderModel(encoder_decoder_config)
|
||||
model(tf_inputs_dict)
|
||||
self.check_pt_to_tf_equivalence(config, decoder_config, tf_inputs_dict)
|
||||
self.check_tf_to_pt_equivalence(config, decoder_config, tf_inputs_dict)
|
||||
|
||||
def test_model_save_load_from_pretrained(self):
|
||||
model_2 = self.get_pretrained_model()
|
||||
|
||||
Reference in New Issue
Block a user