From 47500b1d72021b4820a531fca3f6ab7e5a517106 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 7 Dec 2023 14:28:53 +0000 Subject: [PATCH] Fix TF loading PT safetensors when weights are tied (#27490) * Un-skip tests * Add aliasing support to tf_to_pt_weight_rename * Refactor tf-to-pt weight rename for simplicity * Patch mobilebert * Let us pray that the transfo-xl one works * Add XGLM rename * Expand the test to see if we can get more models to break * Expand the test to see if we can get more models to break * Fix MPNet (it was actually an unrelated bug) * Fix MPNet (it was actually an unrelated bug) * Add speech2text fix * Update src/transformers/modeling_tf_pytorch_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/mobilebert/modeling_tf_mobilebert.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update to always return a tuple from tf_to_pt_weight_rename * reformat * Add a couple of missing tuples * Remove the extra test for tie_word_embeddings since it didn't cause any unexpected failures anyway * Revert changes to modeling_tf_mpnet.py * Skip MPNet test and add explanation * Add weight link for BART * Add TODO to clean this up a bit --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/modeling_tf_pytorch_utils.py | 9 +++- src/transformers/modeling_tf_utils.py | 6 +++ .../models/bart/modeling_tf_bart.py | 6 +++ .../transfo_xl/modeling_tf_transfo_xl.py | 15 +++++++ .../modeling_tf_encoder_decoder.py | 28 +++--------- .../mobilebert/modeling_tf_mobilebert.py | 12 +++++ .../modeling_tf_speech_to_text.py | 6 +++ .../modeling_tf_vision_encoder_decoder.py | 45 +++---------------- .../modeling_tf_vision_text_dual_encoder.py | 38 +++++++--------- .../models/xglm/modeling_tf_xglm.py | 6 +++ .../mobilebert/test_modeling_mobilebert.py | 4 -- tests/models/mpnet/test_modeling_mpnet.py | 2 +- .../test_modeling_speech_to_text_2.py | 4 -- tests/models/xglm/test_modeling_xglm.py | 4 -- 14 files changed, 87 insertions(+), 98 deletions(-) diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py index d45b95fa5b..c599b795bf 100644 --- a/src/transformers/modeling_tf_pytorch_utils.py +++ b/src/transformers/modeling_tf_pytorch_utils.py @@ -318,7 +318,14 @@ def load_pytorch_state_dict_in_tf2_model( name_scope=_prefix, ) if tf_to_pt_weight_rename is not None: - name = tf_to_pt_weight_rename(name) + aliases = tf_to_pt_weight_rename(name) # Is a tuple to account for possible name aliasing + for alias in aliases: # The aliases are in priority order, take the first one that matches + if alias in tf_keys_to_pt_keys: + name = alias + break + else: + # If none of the aliases match, just use the first one (it'll be reported as missing) + name = aliases[0] # Find associated numpy array in pytorch model state dict if name not in tf_keys_to_pt_keys: diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index bfd928a901..00fe790252 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -2892,6 +2892,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu # Instantiate model. model = cls(config, *model_args, **model_kwargs) + if tf_to_pt_weight_rename is None and hasattr(model, "tf_to_pt_weight_rename"): + # TODO Matt: This is a temporary workaround to allow weight renaming, but requires a method + # to be defined for each class that requires a rename. We can probably just have a class-level + # dict and a single top-level method or something and cut down a lot of boilerplate code + tf_to_pt_weight_rename = model.tf_to_pt_weight_rename + if from_pt: from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py index 497dad4249..b04e3ed997 100644 --- a/src/transformers/models/bart/modeling_tf_bart.py +++ b/src/transformers/models/bart/modeling_tf_bart.py @@ -494,6 +494,12 @@ class TFBartPretrainedModel(TFPreTrainedModel): dummy_inputs["decoder_input_ids"] = dummy_inputs["decoder_input_ids"] * 2 return dummy_inputs + def tf_to_pt_weight_rename(self, tf_weight): + if tf_weight == "model.shared.weight": + return tf_weight, "model.decoder.embed_tokens.weight" + else: + return (tf_weight,) + BART_START_DOCSTRING = r""" This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the diff --git a/src/transformers/models/deprecated/transfo_xl/modeling_tf_transfo_xl.py b/src/transformers/models/deprecated/transfo_xl/modeling_tf_transfo_xl.py index 45a4ea56fd..9ae32f8ceb 100644 --- a/src/transformers/models/deprecated/transfo_xl/modeling_tf_transfo_xl.py +++ b/src/transformers/models/deprecated/transfo_xl/modeling_tf_transfo_xl.py @@ -987,6 +987,21 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel): return inputs + # Adapted from the torch tie_weights function + def tf_to_pt_weight_rename(self, tf_weight): + if self.config.tie_word_embeddings and "crit.out_layers" in tf_weight: + return tf_weight, tf_weight.replace("crit.out_layers", "transformer.word_emb.emb_layers") + elif self.config.tie_projs and "crit.out_projs" in tf_weight: + for i, tie_proj in enumerate(self.config.tie_projs): + if tie_proj and self.config.div_val == 1 and self.config.d_model != self.config.d_embed: + # self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[0] + return tf_weight, tf_weight.replace(f"crit.out_projs.{i}", "transformer.word_emb.emb_projs.0") + elif tie_proj and self.config.div_val != 1: + # self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[i] + return tf_weight, tf_weight.replace("crit.out_projs", "transformer.word_emb.emb_projs") + else: + return (tf_weight,) + @add_start_docstrings( """ diff --git a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py index 5b4fc5884c..afd8963359 100644 --- a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py @@ -291,16 +291,7 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss): def set_output_embeddings(self, new_embeddings): return self.decoder.set_output_embeddings(new_embeddings) - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - r""" - Example: - - ```python - >>> from transformers import TFEncoderDecoderModel - - >>> model = TFEncoderDecoderModel.from_pretrained("ydshieh/bert2bert-cnn_dailymail-fp16") - ```""" + def tf_to_pt_weight_rename(self, tf_weight): # Matt: The TF and PT weights don't align because our TF base classes have an extra layer compared to PT models # (the main model stem is in the MainLayer class). If we remove that layer, then weight names sync up as normal. # However, the name of that extra layer is the name of the MainLayer in the base model. We make the assumption @@ -311,18 +302,11 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss): # often safetensors now, we don't know if we're going to be crossloading until we sniff the weights file. # Therefore, we specify tf_to_pt_weight_rename anyway, and let the super method figure out if it needs it # or not. - - config = AutoConfig.from_pretrained(pretrained_model_name_or_path) - encoder_model_type = config.encoder.model_type - - def tf_to_pt_weight_rename(tf_weight): - if "encoder" in tf_weight and "decoder" not in tf_weight: - return re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight) - else: - return tf_weight - - kwargs["tf_to_pt_weight_rename"] = tf_to_pt_weight_rename - return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + encoder_model_type = self.config.encoder.model_type + if "encoder" in tf_weight and "decoder" not in tf_weight: + return (re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight),) + else: + return (tf_weight,) @classmethod def from_encoder_decoder_pretrained( diff --git a/src/transformers/models/mobilebert/modeling_tf_mobilebert.py b/src/transformers/models/mobilebert/modeling_tf_mobilebert.py index bc508a4798..ecf9b65c2c 100644 --- a/src/transformers/models/mobilebert/modeling_tf_mobilebert.py +++ b/src/transformers/models/mobilebert/modeling_tf_mobilebert.py @@ -1088,6 +1088,12 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel, TFMobileBertPreTra attentions=outputs.attentions, ) + def tf_to_pt_weight_rename(self, tf_weight): + if tf_weight == "cls.predictions.decoder.weight": + return tf_weight, "mobilebert.embeddings.word_embeddings.weight" + else: + return (tf_weight,) + @add_start_docstrings("""MobileBert Model with a `language modeling` head on top.""", MOBILEBERT_START_DOCSTRING) class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModelingLoss): @@ -1168,6 +1174,12 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel attentions=outputs.attentions, ) + def tf_to_pt_weight_rename(self, tf_weight): + if tf_weight == "cls.predictions.decoder.weight": + return tf_weight, "mobilebert.embeddings.word_embeddings.weight" + else: + return (tf_weight,) + class TFMobileBertOnlyNSPHead(tf.keras.layers.Layer): def __init__(self, config, **kwargs): diff --git a/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py index 026d2241b4..4c6d2ffcb3 100755 --- a/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py @@ -1460,3 +1460,9 @@ class TFSpeech2TextForConditionalGeneration(TFSpeech2TextPreTrainedModel, TFCaus "cross_attn_head_mask": cross_attn_head_mask, "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } + + def tf_to_pt_weight_rename(self, tf_weight): + if tf_weight == "lm_head.weight": + return tf_weight, "model.decoder.embed_tokens.weight" + else: + return (tf_weight,) diff --git a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py index 65f55d55e8..395d02bf0b 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py @@ -290,33 +290,7 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos def set_output_embeddings(self, new_embeddings): return self.decoder.set_output_embeddings(new_embeddings) - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - r""" - Example: - - ```python - >>> from transformers import TFVisionEncoderDecoderModel, AutoImageProcessor, AutoTokenizer - >>> from PIL import Image - >>> import requests - - >>> image_processor = AutoImageProcessor.from_pretrained("ydshieh/vit-gpt2-coco-en") - >>> decoder_tokenizer = AutoTokenizer.from_pretrained("ydshieh/vit-gpt2-coco-en") - >>> model = TFVisionEncoderDecoderModel.from_pretrained("ydshieh/vit-gpt2-coco-en") - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> img = Image.open(requests.get(url, stream=True).raw) - >>> pixel_values = image_processor(images=img, return_tensors="tf").pixel_values # Batch size 1 - - >>> output_ids = model.generate( - ... pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True - ... ).sequences - - >>> preds = decoder_tokenizer.batch_decode(output_ids, skip_special_tokens=True) - >>> preds = [pred.strip() for pred in preds] - - >>> assert preds == ["a cat laying on top of a couch next to another cat"] - ```""" + def tf_to_pt_weight_rename(self, tf_weight): # Matt: The TF and PT weights don't align because our TF base classes have an extra layer compared to PT models # (the main model stem is in the MainLayer class). If we remove that layer, then weight names sync up as normal. # However, the name of that extra layer is the name of the MainLayer in the base model. We make the assumption @@ -327,18 +301,11 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos # often safetensors now, we don't know if we're going to be crossloading until we sniff the weights file. # Therefore, we specify tf_to_pt_weight_rename anyway, and let the super method figure out if it needs it # or not. - - config = AutoConfig.from_pretrained(pretrained_model_name_or_path) - encoder_model_type = config.encoder.model_type - - def tf_to_pt_weight_rename(tf_weight): - if "encoder" in tf_weight and "decoder" not in tf_weight: - return re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight) - else: - return tf_weight - - kwargs["tf_to_pt_weight_rename"] = tf_to_pt_weight_rename - return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + encoder_model_type = self.config.encoder.model_type + if "encoder" in tf_weight and "decoder" not in tf_weight: + return (re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight),) + else: + return (tf_weight,) @classmethod def from_encoder_decoder_pretrained( diff --git a/src/transformers/models/vision_text_dual_encoder/modeling_tf_vision_text_dual_encoder.py b/src/transformers/models/vision_text_dual_encoder/modeling_tf_vision_text_dual_encoder.py index 34349c8661..d0e91640f6 100644 --- a/src/transformers/models/vision_text_dual_encoder/modeling_tf_vision_text_dual_encoder.py +++ b/src/transformers/models/vision_text_dual_encoder/modeling_tf_vision_text_dual_encoder.py @@ -227,32 +227,24 @@ class TFVisionTextDualEncoderModel(TFPreTrainedModel): self.logit_scale = self.add_weight(shape=(1,), initializer=initializer, name="logit_scale") super().build(input_shape) - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + def tf_to_pt_weight_rename(self, tf_weight): # Matt: The TF and PT weights don't align because our TF base classes have an extra layer compared to PT models # (the main model stem is in the MainLayer class). If we remove that layer, then weight names sync up as normal. # However, the name of that extra layer is the name of the MainLayer in the base model. - - if kwargs.get("from_pt", False): - - def tf_to_pt_weight_rename(tf_weight): - if "vision_model" in tf_weight: - if tf_weight.count("vision_model") == 1: - return re.sub(r"vision_model\..*?\.", "vision_model.", tf_weight) - elif tf_weight.count("vision_model") == 2: - return re.sub(r"vision_model\..*?\.vision_model", "vision_model.vision_model", tf_weight) - else: - raise ValueError( - f"Unexpected weight name {tf_weight}. Please file an issue on the" - " Transformers repo to let us know about this error!" - ) - elif "text_model" in tf_weight: - return re.sub(r"text_model\..*?\.", "text_model.", tf_weight) - else: - return tf_weight - - kwargs["tf_to_pt_weight_rename"] = tf_to_pt_weight_rename - return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + if "vision_model" in tf_weight: + if tf_weight.count("vision_model") == 1: + return re.sub(r"vision_model\..*?\.", "vision_model.", tf_weight) + elif tf_weight.count("vision_model") == 2: + return re.sub(r"vision_model\..*?\.vision_model", "vision_model.vision_model", tf_weight) + else: + raise ValueError( + f"Unexpected weight name {tf_weight}. Please file an issue on the" + " Transformers repo to let us know about this error!" + ) + elif "text_model" in tf_weight: + return re.sub(r"text_model\..*?\.", "text_model.", tf_weight) + else: + return (tf_weight,) @add_start_docstrings_to_model_forward(VISION_TEXT_DUAL_ENCODER_TEXT_INPUTS_DOCSTRING) def get_text_features( diff --git a/src/transformers/models/xglm/modeling_tf_xglm.py b/src/transformers/models/xglm/modeling_tf_xglm.py index e2890edeb6..05f87eb5d3 100644 --- a/src/transformers/models/xglm/modeling_tf_xglm.py +++ b/src/transformers/models/xglm/modeling_tf_xglm.py @@ -924,3 +924,9 @@ class TFXGLMForCausalLM(TFXGLMPreTrainedModel, TFCausalLanguageModelingLoss): attentions=outputs.attentions, cross_attentions=outputs.cross_attentions, ) + + def tf_to_pt_weight_rename(self, tf_weight): + if tf_weight == "lm_head.weight": + return tf_weight, "model.embed_tokens.weight" + else: + return (tf_weight,) diff --git a/tests/models/mobilebert/test_modeling_mobilebert.py b/tests/models/mobilebert/test_modeling_mobilebert.py index a914ce578d..e4ebca4b6e 100644 --- a/tests/models/mobilebert/test_modeling_mobilebert.py +++ b/tests/models/mobilebert/test_modeling_mobilebert.py @@ -302,10 +302,6 @@ class MobileBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa def test_resize_tokens_embeddings(self): super().test_resize_tokens_embeddings() - @unittest.skip("This test is currently broken because of safetensors.") - def test_tf_from_pt_safetensors(self): - pass - def setUp(self): self.model_tester = MobileBertModelTester(self) self.config_tester = ConfigTester(self, config_class=MobileBertConfig, hidden_size=37) diff --git a/tests/models/mpnet/test_modeling_mpnet.py b/tests/models/mpnet/test_modeling_mpnet.py index 52d8d1f8b4..10c0c164d1 100644 --- a/tests/models/mpnet/test_modeling_mpnet.py +++ b/tests/models/mpnet/test_modeling_mpnet.py @@ -246,7 +246,7 @@ class MPNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_mpnet_for_question_answering(*config_and_inputs) - @unittest.skip("This isn't passing but should, seems like a misconfiguration of tied weights.") + @unittest.skip("TFMPNet adds poolers to all models, unlike the PT model class.") def test_tf_from_pt_safetensors(self): return diff --git a/tests/models/speech_to_text_2/test_modeling_speech_to_text_2.py b/tests/models/speech_to_text_2/test_modeling_speech_to_text_2.py index b2220a9e74..cbb449c6e7 100644 --- a/tests/models/speech_to_text_2/test_modeling_speech_to_text_2.py +++ b/tests/models/speech_to_text_2/test_modeling_speech_to_text_2.py @@ -196,10 +196,6 @@ class Speech2Text2StandaloneDecoderModelTest( def test_inputs_embeds(self): pass - @unittest.skip("This test is currently broken because of safetensors.") - def test_tf_from_pt_safetensors(self): - pass - # speech2text2 has no base model def test_save_load_fast_init_from_base(self): pass diff --git a/tests/models/xglm/test_modeling_xglm.py b/tests/models/xglm/test_modeling_xglm.py index 457317f078..e482b1b384 100644 --- a/tests/models/xglm/test_modeling_xglm.py +++ b/tests/models/xglm/test_modeling_xglm.py @@ -357,10 +357,6 @@ class XGLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin def test_model_parallelism(self): super().test_model_parallelism() - @unittest.skip("This test is currently broken because of safetensors.") - def test_tf_from_pt_safetensors(self): - pass - @require_torch class XGLMModelLanguageGenerationTest(unittest.TestCase):