Implemented safetensors checkpoints save/load for Trainer (#22498)

* implemented safetensors save/load

* remove duplicated file

* added tests

* more tests

* style fix

* fix tf tests

* change to list comprehension

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* review fixes + safe load for sharded checkpoint

* style fix

* remove rogue import

* remove partial to avoid undefined exception

* use naming alias instead of safetensors.torch

* fix safe sharding in tests

* grammar

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* update docs

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* update docs

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* minor corrections

* style

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Viktor Scherbakov
2023-04-04 16:05:04 +03:00
committed by GitHub
parent 00b5887b94
commit 871598be55
4 changed files with 231 additions and 36 deletions

View File

@@ -42,6 +42,7 @@ from .utils import (
get_full_repo_name,
is_accelerate_available,
is_psutil_available,
is_safetensors_available,
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_torch_available,
@@ -261,6 +262,9 @@ class TrainingArguments:
save_total_limit (`int`, *optional*):
If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in
`output_dir`.
save_safetensors (`bool`, *optional*, defaults to `False`):
Use [safetensors](https://huggingface.co/docs/safetensors) saving and loading for state dicts instead of the
default `torch.load` and `torch.save`.
save_on_each_node (`bool`, *optional*, defaults to `False`):
When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on
the main one.
@@ -720,6 +724,12 @@ class TrainingArguments:
)
},
)
# Opt-in flag, off by default: when set, state dicts are saved and loaded with safetensors
# rather than the default torch.save / torch.load.
save_safetensors: Optional[bool] = field(
default=False,
metadata={
"help": "Use safetensors saving and loading for state dicts instead of default torch.load and torch.save."
},
)
save_on_each_node: bool = field(
default=False,
metadata={
@@ -1166,6 +1176,17 @@ class TrainingArguments:
f"steps, but found {self.save_steps}, which is not a round multiple of {self.eval_steps}."
)
safetensors_available = is_safetensors_available()
if self.save_safetensors and not safetensors_available:
raise ValueError(f"--save_safetensors={self.save_safetensors} requires safetensors to be installed!")
if not self.save_safetensors and safetensors_available:
logger.info(
f"Found safetensors installation, but --save_safetensors={self.save_safetensors}. "
f"Safetensors should be a preferred weights saving format due to security and performance reasons. "
f"If your model cannot be saved by safetensors please feel free to open an issue at "
f"https://github.com/huggingface/safetensors!"
)
if self.load_best_model_at_end and self.metric_for_best_model is None:
self.metric_for_best_model = "loss"
if self.greater_is_better is None and self.metric_for_best_model is not None: