Implemented safetensors checkpoints save/load for Trainer (#22498)
* implemented safetensors save/load * remove duplicated file * added tests * more tests * style fix * fix tf tests * change to list comprehension Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * review fixes + safe load for sharded checkpoint * style fix * remove rogue import * remove partial to avoid undefined exception * use naming alias instead of safetensors.torch * fix safe sharding in tests * grammar Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * update docs Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * update docs Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * minor corrections * style --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
00b5887b94
commit
871598be55
@@ -42,6 +42,7 @@ from .utils import (
|
||||
get_full_repo_name,
|
||||
is_accelerate_available,
|
||||
is_psutil_available,
|
||||
is_safetensors_available,
|
||||
is_sagemaker_dp_enabled,
|
||||
is_sagemaker_mp_enabled,
|
||||
is_torch_available,
|
||||
@@ -261,6 +262,9 @@ class TrainingArguments:
|
||||
save_total_limit (`int`, *optional*):
|
||||
If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in
|
||||
`output_dir`.
|
||||
save_safetensors (`bool`, *optional*, defaults to `False`):
|
||||
Use [safetensors](https://huggingface.co/docs/safetensors) saving and loading for state dicts instead of
|
||||
default `torch.load` and `torch.save`.
|
||||
save_on_each_node (`bool`, *optional*, defaults to `False`):
|
||||
When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on
|
||||
the main one.
|
||||
@@ -720,6 +724,12 @@ class TrainingArguments:
|
||||
)
|
||||
},
|
||||
)
|
||||
save_safetensors: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Use safetensors saving and loading for state dicts instead of default torch.load and torch.save."
|
||||
},
|
||||
)
|
||||
save_on_each_node: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
@@ -1166,6 +1176,17 @@ class TrainingArguments:
|
||||
f"steps, but found {self.save_steps}, which is not a round multiple of {self.eval_steps}."
|
||||
)
|
||||
|
||||
safetensors_available = is_safetensors_available()
|
||||
if self.save_safetensors and not safetensors_available:
|
||||
raise ValueError(f"--save_safetensors={self.save_safetensors} requires safetensors to be installed!")
|
||||
if not self.save_safetensors and safetensors_available:
|
||||
logger.info(
|
||||
f"Found safetensors installation, but --save_safetensors={self.save_safetensors}. "
|
||||
f"Safetensors should be a preferred weights saving format due to security and performance reasons. "
|
||||
f"If your model cannot be saved by safetensors please feel free to open an issue at "
|
||||
f"https://github.com/huggingface/safetensors!"
|
||||
)
|
||||
|
||||
if self.load_best_model_at_end and self.metric_for_best_model is None:
|
||||
self.metric_for_best_model = "loss"
|
||||
if self.greater_is_better is None and self.metric_for_best_model is not None:
|
||||
|
||||
Reference in New Issue
Block a user