Safetensors serialization by default (#27064)

* Safetensors serialization by default

* First pass on the tests

* Second pass on the tests

* Third pass on the tests

* Fix TF weight loading from TF-format safetensors

* Specific encoder-decoder fixes for weight crossloading

* Add VisionEncoderDecoder fixes for TF too

* Change filename test for pt-to-tf

* One missing fix for TFVisionEncoderDecoder

* Fix the other crossload test

* Support for flax + updated tests

* Apply suggestions from code review

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Sanchit's comments

* Sanchit's comments 2

* Nico's comments

* Fix tests

* cleanup

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: Matt <rocketknight1@gmail.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Lysandre Debut
2023-10-31 19:16:49 +01:00
committed by GitHub
parent 25e6e9418c
commit 113ebf80ac
20 changed files with 433 additions and 137 deletions

View File

@@ -43,7 +43,6 @@ class CLITest(unittest.TestCase):
shutil.rmtree("/tmp/hf-internal-testing/tiny-random-gptj", ignore_errors=True) # cleans potential past runs
transformers.commands.transformers_cli.main()
# The original repo has no TF weights -- if they exist, they were created by the CLI
self.assertTrue(os.path.exists("/tmp/hf-internal-testing/tiny-random-gptj/tf_model.h5"))
@require_torch