Safetensors serialization by default (#27064)

* Safetensors serialization by default

* First pass on the tests

* Second pass on the tests

* Third pass on the tests

* Fix TF weight loading from TF-format safetensors

* Specific encoder-decoder fixes for weight crossloading

* Add VisionEncoderDecoder fixes for TF too

* Change filename test for pt-to-tf

* One missing fix for TFVisionEncoderDecoder

* Fix the other crossload test

* Support for flax + updated tests

* Apply suggestions from code review

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Sanchit's comments

* Sanchit's comments 2

* Nico's comments

* Fix tests

* cleanup

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: Matt <rocketknight1@gmail.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Lysandre Debut
2023-10-31 19:16:49 +01:00
committed by GitHub
parent 25e6e9418c
commit 113ebf80ac
20 changed files with 433 additions and 137 deletions

View File

@@ -403,7 +403,7 @@ if is_torch_available():
class TrainerIntegrationCommon:
def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True, safe_weights=False):
def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True, safe_weights=True):
weights_file = WEIGHTS_NAME if not safe_weights else SAFE_WEIGHTS_NAME
file_list = [weights_file, "training_args.bin", "optimizer.pt", "scheduler.pt", "trainer_state.json"]
if is_pretrained:
@@ -415,7 +415,7 @@ class TrainerIntegrationCommon:
self.assertTrue(os.path.isfile(os.path.join(checkpoint, filename)))
def check_best_model_has_been_loaded(
self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True, safe_weights=False
self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True, safe_weights=True
):
checkpoint = os.path.join(output_dir, f"checkpoint-{(total // freq) * freq}")
log_history = TrainerState.load_from_json(os.path.join(checkpoint, "trainer_state.json")).log_history
@@ -456,7 +456,7 @@ class TrainerIntegrationCommon:
_ = log1.pop(key, None)
self.assertEqual(log, log1)
def convert_to_sharded_checkpoint(self, folder, save_safe=False, load_safe=False):
def convert_to_sharded_checkpoint(self, folder, save_safe=True, load_safe=True):
# Converts a checkpoint of a regression model to a sharded checkpoint.
if load_safe:
loader = safetensors.torch.load_file