From 716819b8309324302e00a3488a3c3d6faa427f79 Mon Sep 17 00:00:00 2001 From: Arjuna Sky Kok <32124593+arjunaskykok@users.noreply.github.com> Date: Sat, 10 May 2025 18:11:07 +0700 Subject: [PATCH] fix(conversion): Fix size mismatch error during TF->PT model loading (#38014) --- src/transformers/modeling_tf_pytorch_utils.py | 4 ++-- tests/utils/test_modeling_utils.py | 8 ++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py index 51c21bb7fa..24bdf4faa0 100644 --- a/src/transformers/modeling_tf_pytorch_utils.py +++ b/src/transformers/modeling_tf_pytorch_utils.py @@ -585,8 +585,8 @@ def load_tf2_state_dict_in_pytorch_model(pt_model, tf_state_dict, allow_missing_ loaded_pt_weights_data_ptr = {} missing_keys_pt = [] for pt_weight_name, pt_weight in current_pt_params_dict.items(): - # Handle PyTorch shared weight ()not duplicated in TF 2.0 - if pt_weight.data_ptr() in loaded_pt_weights_data_ptr: + # Handle PyTorch shared weight not duplicated in TF 2.0 + if pt_weight.data_ptr() in loaded_pt_weights_data_ptr and pt_weight.data_ptr() != 0: new_pt_params_dict[pt_weight_name] = loaded_pt_weights_data_ptr[pt_weight.data_ptr()] continue diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 2df3384963..a5aef44c38 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -1571,6 +1571,14 @@ class ModelUtilsTest(TestCasePlus): for p1, p2 in zip(hub_model.parameters(), new_model.parameters()): self.assertTrue(torch.equal(p1, p2)) + @require_tf + def test_torch_from_tf(self): + model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-only") + + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir) + _ = BertModel.from_pretrained(tmp_dir, from_tf=True) + @require_safetensors def test_safetensors_torch_from_torch_sharded(self): model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")