Fix TF loading PT safetensors when weights are tied (#27490)
* Un-skip tests * Add aliasing support to tf_to_pt_weight_rename * Refactor tf-to-pt weight rename for simplicity * Patch mobilebert * Let us pray that the transfo-xl one works * Add XGLM rename * Expand the test to see if we can get more models to break * Expand the test to see if we can get more models to break * Fix MPNet (it was actually an unrelated bug) * Fix MPNet (it was actually an unrelated bug) * Add speech2text fix * Update src/transformers/modeling_tf_pytorch_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/mobilebert/modeling_tf_mobilebert.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update to always return a tuple from tf_to_pt_weight_rename * reformat * Add a couple of missing tuples * Remove the extra test for tie_word_embeddings since it didn't cause any unexpected failures anyway * Revert changes to modeling_tf_mpnet.py * Skip MPNet test and add explanation * Add weight link for BART * Add TODO to clean this up a bit --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -357,10 +357,6 @@ class XGLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
def test_model_parallelism(self):
|
||||
super().test_model_parallelism()
|
||||
|
||||
@unittest.skip("This test is currently broken because of safetensors.")
|
||||
def test_tf_from_pt_safetensors(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class XGLMModelLanguageGenerationTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user