Fix TF loading PT safetensors when weights are tied (#27490)
* Un-skip tests
* Add aliasing support to tf_to_pt_weight_rename
* Refactor tf-to-pt weight rename for simplicity
* Patch mobilebert
* Let us pray that the transfo-xl one works
* Add XGLM rename
* Expand the test to see if we can get more models to break
* Expand the test to see if we can get more models to break
* Fix MPNet (it was actually an unrelated bug)
* Fix MPNet (it was actually an unrelated bug)
* Add speech2text fix
* Update src/transformers/modeling_tf_pytorch_utils.py
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/mobilebert/modeling_tf_mobilebert.py
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update to always return a tuple from tf_to_pt_weight_rename
* reformat
* Add a couple of missing tuples
* Remove the extra test for tie_word_embeddings since it didn't cause any unexpected failures anyway
* Revert changes to modeling_tf_mpnet.py
* Skip MPNet test and add explanation
* Add weight link for BART
* Add TODO to clean this up a bit

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -318,7 +318,14 @@ def load_pytorch_state_dict_in_tf2_model(
|
||||
name_scope=_prefix,
|
||||
)
|
||||
if tf_to_pt_weight_rename is not None:
|
||||
name = tf_to_pt_weight_rename(name)
|
||||
aliases = tf_to_pt_weight_rename(name) # Is a tuple to account for possible name aliasing
|
||||
for alias in aliases: # The aliases are in priority order, take the first one that matches
|
||||
if alias in tf_keys_to_pt_keys:
|
||||
name = alias
|
||||
break
|
||||
else:
|
||||
# If none of the aliases match, just use the first one (it'll be reported as missing)
|
||||
name = aliases[0]
|
||||
|
||||
# Find associated numpy array in pytorch model state dict
|
||||
if name not in tf_keys_to_pt_keys:
|
||||
|
||||
@@ -2892,6 +2892,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
# Instantiate model.
|
||||
model = cls(config, *model_args, **model_kwargs)
|
||||
|
||||
if tf_to_pt_weight_rename is None and hasattr(model, "tf_to_pt_weight_rename"):
|
||||
# TODO Matt: This is a temporary workaround to allow weight renaming, but requires a method
|
||||
# to be defined for each class that requires a rename. We can probably just have a class-level
|
||||
# dict and a single top-level method or something and cut down a lot of boilerplate code
|
||||
tf_to_pt_weight_rename = model.tf_to_pt_weight_rename
|
||||
|
||||
if from_pt:
|
||||
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
|
||||
|
||||
|
||||
@@ -494,6 +494,12 @@ class TFBartPretrainedModel(TFPreTrainedModel):
|
||||
dummy_inputs["decoder_input_ids"] = dummy_inputs["decoder_input_ids"] * 2
|
||||
return dummy_inputs
|
||||
|
||||
def tf_to_pt_weight_rename(self, tf_weight):
    """Return the candidate PyTorch names for a TF weight, in priority order.

    Args:
        tf_weight: Name of the TF weight, already converted to PyTorch-style
            dotted notation.

    Returns:
        A tuple of names. The TF name itself always comes first; for the
        shared embedding matrix the decoder embedding name is offered as a
        fallback alias (presumably because a checkpoint with tied weights may
        store the tensor under only one of the two names).
    """
    if tf_weight == "model.shared.weight":
        return tf_weight, "model.decoder.embed_tokens.weight"
    return (tf_weight,)
|
||||
|
||||
|
||||
BART_START_DOCSTRING = r"""
|
||||
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
|
||||
@@ -987,6 +987,21 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
|
||||
|
||||
return inputs
|
||||
|
||||
# Adapted from the torch tie_weights function
def tf_to_pt_weight_rename(self, tf_weight):
    """Return the candidate PyTorch names for a TF weight, in priority order.

    The adaptive-softmax output layers/projections (``crit.*``) may be tied to
    the adaptive input embedding (``transformer.word_emb.*``). For a tied
    weight, the name of the tensor it is tied to is returned as a second,
    lower-priority alias.

    Args:
        tf_weight: Name of the TF weight, already converted to PyTorch-style
            dotted notation.

    Returns:
        A tuple of one or two names; the unmodified ``tf_weight`` is always
        first. A tuple is returned on every path so callers can iterate it.
    """
    config = self.config

    if config.tie_word_embeddings and "crit.out_layers" in tf_weight:
        return tf_weight, tf_weight.replace("crit.out_layers", "transformer.word_emb.emb_layers")

    if config.tie_projs and "crit.out_projs" in tf_weight:
        # Work out which projection this weight belongs to, so the tie flag of
        # *that* projection is consulted rather than the first tied one.
        name_parts = tf_weight.split(".")
        try:
            proj_index = int(name_parts[name_parts.index("out_projs") + 1])
        except (ValueError, IndexError):
            return (tf_weight,)

        if 0 <= proj_index < len(config.tie_projs) and config.tie_projs[proj_index]:
            if config.div_val == 1 and config.d_model != config.d_embed:
                # self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[0]
                return tf_weight, tf_weight.replace(
                    f"crit.out_projs.{proj_index}", "transformer.word_emb.emb_projs.0"
                )
            if config.div_val != 1:
                # self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[i]
                return tf_weight, tf_weight.replace("crit.out_projs", "transformer.word_emb.emb_projs")

    # Untied weight (or a tied configuration that needs no alias).
    return (tf_weight,)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
|
||||
@@ -291,16 +291,7 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
return self.decoder.set_output_embeddings(new_embeddings)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
r"""
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import TFEncoderDecoderModel
|
||||
|
||||
>>> model = TFEncoderDecoderModel.from_pretrained("ydshieh/bert2bert-cnn_dailymail-fp16")
|
||||
```"""
|
||||
def tf_to_pt_weight_rename(self, tf_weight):
|
||||
# Matt: The TF and PT weights don't align because our TF base classes have an extra layer compared to PT models
|
||||
# (the main model stem is in the MainLayer class). If we remove that layer, then weight names sync up as normal.
|
||||
# However, the name of that extra layer is the name of the MainLayer in the base model. We make the assumption
|
||||
@@ -311,18 +302,11 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
# often safetensors now, we don't know if we're going to be crossloading until we sniff the weights file.
|
||||
# Therefore, we specify tf_to_pt_weight_rename anyway, and let the super method figure out if it needs it
|
||||
# or not.
|
||||
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
|
||||
encoder_model_type = config.encoder.model_type
|
||||
|
||||
def tf_to_pt_weight_rename(tf_weight):
|
||||
if "encoder" in tf_weight and "decoder" not in tf_weight:
|
||||
return re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight)
|
||||
else:
|
||||
return tf_weight
|
||||
|
||||
kwargs["tf_to_pt_weight_rename"] = tf_to_pt_weight_rename
|
||||
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
encoder_model_type = self.config.encoder.model_type
|
||||
if "encoder" in tf_weight and "decoder" not in tf_weight:
|
||||
return (re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight),)
|
||||
else:
|
||||
return (tf_weight,)
|
||||
|
||||
@classmethod
|
||||
def from_encoder_decoder_pretrained(
|
||||
|
||||
@@ -1088,6 +1088,12 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel, TFMobileBertPreTra
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
def tf_to_pt_weight_rename(self, tf_weight):
    """Return the candidate PyTorch names for a TF weight, in priority order.

    Args:
        tf_weight: Name of the TF weight, already converted to PyTorch-style
            dotted notation.

    Returns:
        A tuple of names with ``tf_weight`` first. For the prediction-head
        decoder weight, the word-embedding weight name is added as a fallback
        alias (presumably because the two are tied and a checkpoint may hold
        only one of them).
    """
    if tf_weight == "cls.predictions.decoder.weight":
        return tf_weight, "mobilebert.embeddings.word_embeddings.weight"
    return (tf_weight,)
|
||||
|
||||
|
||||
@add_start_docstrings("""MobileBert Model with a `language modeling` head on top.""", MOBILEBERT_START_DOCSTRING)
|
||||
class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModelingLoss):
|
||||
@@ -1168,6 +1174,12 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
def tf_to_pt_weight_rename(self, tf_weight):
    """Return the candidate PyTorch names for a TF weight, in priority order.

    Args:
        tf_weight: Name of the TF weight, already converted to PyTorch-style
            dotted notation.

    Returns:
        A tuple of names with ``tf_weight`` first. For the masked-LM decoder
        weight, the word-embedding weight name is added as a fallback alias
        (presumably because the two are tied and a checkpoint may hold only
        one of them).
    """
    if tf_weight == "cls.predictions.decoder.weight":
        return tf_weight, "mobilebert.embeddings.word_embeddings.weight"
    return (tf_weight,)
|
||||
|
||||
|
||||
class TFMobileBertOnlyNSPHead(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
|
||||
@@ -1460,3 +1460,9 @@ class TFSpeech2TextForConditionalGeneration(TFSpeech2TextPreTrainedModel, TFCaus
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||
}
|
||||
|
||||
def tf_to_pt_weight_rename(self, tf_weight):
    """Return the candidate PyTorch names for a TF weight, in priority order.

    Args:
        tf_weight: Name of the TF weight, already converted to PyTorch-style
            dotted notation.

    Returns:
        A tuple of names with ``tf_weight`` first. For ``lm_head.weight`` the
        decoder token-embedding name is added as a fallback alias (presumably
        because the LM head is tied to the decoder embeddings).
    """
    if tf_weight == "lm_head.weight":
        return tf_weight, "model.decoder.embed_tokens.weight"
    return (tf_weight,)
|
||||
|
||||
@@ -290,33 +290,7 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
return self.decoder.set_output_embeddings(new_embeddings)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
r"""
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import TFVisionEncoderDecoderModel, AutoImageProcessor, AutoTokenizer
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
|
||||
>>> image_processor = AutoImageProcessor.from_pretrained("ydshieh/vit-gpt2-coco-en")
|
||||
>>> decoder_tokenizer = AutoTokenizer.from_pretrained("ydshieh/vit-gpt2-coco-en")
|
||||
>>> model = TFVisionEncoderDecoderModel.from_pretrained("ydshieh/vit-gpt2-coco-en")
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> img = Image.open(requests.get(url, stream=True).raw)
|
||||
>>> pixel_values = image_processor(images=img, return_tensors="tf").pixel_values # Batch size 1
|
||||
|
||||
>>> output_ids = model.generate(
|
||||
... pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True
|
||||
... ).sequences
|
||||
|
||||
>>> preds = decoder_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
>>> preds = [pred.strip() for pred in preds]
|
||||
|
||||
>>> assert preds == ["a cat laying on top of a couch next to another cat"]
|
||||
```"""
|
||||
def tf_to_pt_weight_rename(self, tf_weight):
|
||||
# Matt: The TF and PT weights don't align because our TF base classes have an extra layer compared to PT models
|
||||
# (the main model stem is in the MainLayer class). If we remove that layer, then weight names sync up as normal.
|
||||
# However, the name of that extra layer is the name of the MainLayer in the base model. We make the assumption
|
||||
@@ -327,18 +301,11 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
|
||||
# often safetensors now, we don't know if we're going to be crossloading until we sniff the weights file.
|
||||
# Therefore, we specify tf_to_pt_weight_rename anyway, and let the super method figure out if it needs it
|
||||
# or not.
|
||||
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
|
||||
encoder_model_type = config.encoder.model_type
|
||||
|
||||
def tf_to_pt_weight_rename(tf_weight):
|
||||
if "encoder" in tf_weight and "decoder" not in tf_weight:
|
||||
return re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight)
|
||||
else:
|
||||
return tf_weight
|
||||
|
||||
kwargs["tf_to_pt_weight_rename"] = tf_to_pt_weight_rename
|
||||
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
encoder_model_type = self.config.encoder.model_type
|
||||
if "encoder" in tf_weight and "decoder" not in tf_weight:
|
||||
return (re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight),)
|
||||
else:
|
||||
return (tf_weight,)
|
||||
|
||||
@classmethod
|
||||
def from_encoder_decoder_pretrained(
|
||||
|
||||
@@ -227,32 +227,24 @@ class TFVisionTextDualEncoderModel(TFPreTrainedModel):
|
||||
self.logit_scale = self.add_weight(shape=(1,), initializer=initializer, name="logit_scale")
|
||||
super().build(input_shape)
|
||||
|
||||
def tf_to_pt_weight_rename(self, tf_weight):
    """Return the candidate PyTorch names for a TF weight, in priority order.

    Args:
        tf_weight: Name of the TF weight, already converted to PyTorch-style
            dotted notation.

    Returns:
        A 1-tuple holding the renamed weight. Every branch returns a tuple
        (never a bare string) because the result is iterated as a sequence of
        aliases; a bare string would be iterated character by character.

    Raises:
        ValueError: If ``"vision_model"`` occurs more than twice in the name.
    """
    # Matt: The TF and PT weights don't align because our TF base classes have an extra layer compared to PT models
    # (the main model stem is in the MainLayer class). If we remove that layer, then weight names sync up as normal.
    # However, the name of that extra layer is the name of the MainLayer in the base model.
    if "vision_model" in tf_weight:
        occurrences = tf_weight.count("vision_model")
        if occurrences == 1:
            return (re.sub(r"vision_model\..*?\.", "vision_model.", tf_weight),)
        elif occurrences == 2:
            return (re.sub(r"vision_model\..*?\.vision_model", "vision_model.vision_model", tf_weight),)
        else:
            raise ValueError(
                f"Unexpected weight name {tf_weight}. Please file an issue on the"
                " Transformers repo to let us know about this error!"
            )
    elif "text_model" in tf_weight:
        return (re.sub(r"text_model\..*?\.", "text_model.", tf_weight),)
    else:
        return (tf_weight,)
|
||||
|
||||
@add_start_docstrings_to_model_forward(VISION_TEXT_DUAL_ENCODER_TEXT_INPUTS_DOCSTRING)
|
||||
def get_text_features(
|
||||
|
||||
@@ -924,3 +924,9 @@ class TFXGLMForCausalLM(TFXGLMPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
attentions=outputs.attentions,
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
)
|
||||
|
||||
def tf_to_pt_weight_rename(self, tf_weight):
    """Return the candidate PyTorch names for a TF weight, in priority order.

    Args:
        tf_weight: Name of the TF weight, already converted to PyTorch-style
            dotted notation.

    Returns:
        A tuple of names with ``tf_weight`` first. For ``lm_head.weight`` the
        token-embedding name is added as a fallback alias (presumably because
        the LM head is tied to the input embeddings).
    """
    if tf_weight == "lm_head.weight":
        return tf_weight, "model.embed_tokens.weight"
    return (tf_weight,)
|
||||
|
||||
@@ -302,10 +302,6 @@ class MobileBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
||||
def test_resize_tokens_embeddings(self):
|
||||
super().test_resize_tokens_embeddings()
|
||||
|
||||
@unittest.skip("This test is currently broken because of safetensors.")
|
||||
def test_tf_from_pt_safetensors(self):
|
||||
pass
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = MobileBertModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=MobileBertConfig, hidden_size=37)
|
||||
|
||||
@@ -246,7 +246,7 @@ class MPNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_mpnet_for_question_answering(*config_and_inputs)
|
||||
|
||||
@unittest.skip("This isn't passing but should, seems like a misconfiguration of tied weights.")
|
||||
@unittest.skip("TFMPNet adds poolers to all models, unlike the PT model class.")
|
||||
def test_tf_from_pt_safetensors(self):
|
||||
return
|
||||
|
||||
|
||||
@@ -196,10 +196,6 @@ class Speech2Text2StandaloneDecoderModelTest(
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("This test is currently broken because of safetensors.")
|
||||
def test_tf_from_pt_safetensors(self):
|
||||
pass
|
||||
|
||||
# speech2text2 has no base model
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
@@ -357,10 +357,6 @@ class XGLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
def test_model_parallelism(self):
|
||||
super().test_model_parallelism()
|
||||
|
||||
@unittest.skip("This test is currently broken because of safetensors.")
|
||||
def test_tf_from_pt_safetensors(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class XGLMModelLanguageGenerationTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user