Safetensors serialization by default (#27064)
* Safetensors serialization by default * First pass on the tests * Second pass on the tests * Third pass on the tests * Fix TF weight loading from TF-format safetensors * Specific encoder-decoder fixes for weight crossloading * Add VisionEncoderDecoder fixes for TF too * Change filename test for pt-to-tf * One missing fix for TFVisionEncoderDecoder * Fix the other crossload test * Support for flax + updated tests * Apply suggestions from code review Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Sanchit's comments * Sanchit's comments 2 * Nico's comments * Fix tests * cleanup * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: Matt <rocketknight1@gmail.com> Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -403,7 +403,7 @@ if is_torch_available():
|
||||
|
||||
|
||||
class TrainerIntegrationCommon:
|
||||
def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True, safe_weights=False):
|
||||
def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True, safe_weights=True):
|
||||
weights_file = WEIGHTS_NAME if not safe_weights else SAFE_WEIGHTS_NAME
|
||||
file_list = [weights_file, "training_args.bin", "optimizer.pt", "scheduler.pt", "trainer_state.json"]
|
||||
if is_pretrained:
|
||||
@@ -415,7 +415,7 @@ class TrainerIntegrationCommon:
|
||||
self.assertTrue(os.path.isfile(os.path.join(checkpoint, filename)))
|
||||
|
||||
def check_best_model_has_been_loaded(
|
||||
self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True, safe_weights=False
|
||||
self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True, safe_weights=True
|
||||
):
|
||||
checkpoint = os.path.join(output_dir, f"checkpoint-{(total // freq) * freq}")
|
||||
log_history = TrainerState.load_from_json(os.path.join(checkpoint, "trainer_state.json")).log_history
|
||||
@@ -456,7 +456,7 @@ class TrainerIntegrationCommon:
|
||||
_ = log1.pop(key, None)
|
||||
self.assertEqual(log, log1)
|
||||
|
||||
def convert_to_sharded_checkpoint(self, folder, save_safe=False, load_safe=False):
|
||||
def convert_to_sharded_checkpoint(self, folder, save_safe=True, load_safe=True):
|
||||
# Converts a checkpoint of a regression model to a sharded checkpoint.
|
||||
if load_safe:
|
||||
loader = safetensors.torch.load_file
|
||||
|
||||
Reference in New Issue
Block a user