Implemented safetensors checkpoints save/load for Trainer (#22498)
* implemented safetensors save/load * remove duplicated file * added tests * more tests * style fix * fix tf tests * change to list comprehension Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * review fixes + safe load for sharded checkpoint * style fix * remove rogue import * remove partial to avoid undefined exception * use naming alias instead of safetensors.torch * fix safe sharding in tests * grammar Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * update docs Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * update docs Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * minor corrections * style --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
00b5887b94
commit
871598be55
@@ -336,7 +336,7 @@ def shard_checkpoint(
|
||||
return shards, index
|
||||
|
||||
|
||||
def load_sharded_checkpoint(model, folder, strict=True):
|
||||
def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
|
||||
"""
|
||||
This is the same as
|
||||
[`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict)
|
||||
@@ -350,6 +350,9 @@ def load_sharded_checkpoint(model, folder, strict=True):
|
||||
folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint.
|
||||
strict (`bool`, *optional`, defaults to `True`):
|
||||
Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
|
||||
prefer_safe (`bool`, *optional*, defaults to `False`)
|
||||
If both safetensors and PyTorch save files are present in checkpoint and `prefer_safe` is True, the
|
||||
safetensors files will be loaded. Otherwise, PyTorch files are always loaded when possible.
|
||||
|
||||
Returns:
|
||||
`NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields
|
||||
@@ -358,10 +361,32 @@ def load_sharded_checkpoint(model, folder, strict=True):
|
||||
"""
|
||||
# Load the index
|
||||
index_file = os.path.join(folder, WEIGHTS_INDEX_NAME)
|
||||
if not os.path.isfile(index_file):
|
||||
raise ValueError(f"Can't find a checkpoint index ({WEIGHTS_INDEX_NAME}) in {folder}.")
|
||||
safe_index_file = os.path.join(folder, SAFE_WEIGHTS_INDEX_NAME)
|
||||
|
||||
with open(index_file, "r", encoding="utf-8") as f:
|
||||
index_present = os.path.isfile(index_file)
|
||||
safe_index_present = os.path.isfile(safe_index_file)
|
||||
|
||||
if not index_present and not (safe_index_present and is_safetensors_available()):
|
||||
filenames = (
|
||||
(WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME) if is_safetensors_available() else (WEIGHTS_INDEX_NAME,)
|
||||
)
|
||||
raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {folder}.")
|
||||
|
||||
load_safe = False
|
||||
if safe_index_present:
|
||||
if prefer_safe:
|
||||
if is_safetensors_available():
|
||||
load_safe = True # load safe due to preference
|
||||
else:
|
||||
logger.warning(
|
||||
f"Cannot load sharded checkpoint at {folder} safely since safetensors is not installed!"
|
||||
)
|
||||
elif not index_present:
|
||||
load_safe = True # load safe since we have no other choice
|
||||
|
||||
load_index = safe_index_file if load_safe else index_file
|
||||
|
||||
with open(load_index, "r", encoding="utf-8") as f:
|
||||
index = json.load(f)
|
||||
|
||||
shard_files = list(set(index["weight_map"].values()))
|
||||
@@ -381,11 +406,13 @@ def load_sharded_checkpoint(model, folder, strict=True):
|
||||
error_message += f"\nMissing key(s): {str_unexpected_keys}."
|
||||
raise RuntimeError(error_message)
|
||||
|
||||
loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu")
|
||||
|
||||
for shard_file in shard_files:
|
||||
state_dict = torch.load(os.path.join(folder, shard_file), map_location="cpu")
|
||||
state_dict = loader(os.path.join(folder, shard_file))
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
# Make sure memory is fred before we load the next state dict.
|
||||
# Make sure memory is freed before we load the next state dict.
|
||||
del state_dict
|
||||
gc.collect()
|
||||
|
||||
|
||||
@@ -135,6 +135,8 @@ from .trainer_utils import (
|
||||
from .training_args import OptimizerNames, ParallelMode, TrainingArguments
|
||||
from .utils import (
|
||||
CONFIG_NAME,
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
SAFE_WEIGHTS_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
WEIGHTS_NAME,
|
||||
can_return_loss,
|
||||
@@ -145,6 +147,7 @@ from .utils import (
|
||||
is_datasets_available,
|
||||
is_in_notebook,
|
||||
is_ipex_available,
|
||||
is_safetensors_available,
|
||||
is_sagemaker_dp_enabled,
|
||||
is_sagemaker_mp_enabled,
|
||||
is_torch_compile_available,
|
||||
@@ -198,6 +201,10 @@ else:
|
||||
IS_SAGEMAKER_MP_POST_1_10 = False
|
||||
|
||||
|
||||
if is_safetensors_available():
|
||||
import safetensors.torch
|
||||
|
||||
|
||||
skip_first_batches = None
|
||||
if is_accelerate_available():
|
||||
from accelerate import __version__ as accelerate_version
|
||||
@@ -2091,15 +2098,22 @@ class Trainer:
|
||||
if model is None:
|
||||
model = self.model
|
||||
|
||||
if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)) and not os.path.isfile(
|
||||
os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
|
||||
config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME)
|
||||
|
||||
weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME)
|
||||
weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
|
||||
safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
|
||||
safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)
|
||||
|
||||
if not any(
|
||||
[os.path.isfile(f) for f in [weights_file, safe_weights_file, weights_index_file, safe_weights_index_file]]
|
||||
):
|
||||
raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
|
||||
|
||||
logger.info(f"Loading model from {resume_from_checkpoint}.")
|
||||
|
||||
if os.path.isfile(os.path.join(resume_from_checkpoint, CONFIG_NAME)):
|
||||
config = PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
|
||||
if os.path.isfile(config_file):
|
||||
config = PretrainedConfig.from_json_file(config_file)
|
||||
checkpoint_version = config.transformers_version
|
||||
if checkpoint_version is not None and checkpoint_version != __version__:
|
||||
logger.warning(
|
||||
@@ -2108,7 +2122,7 @@ class Trainer:
|
||||
"yield to errors or unwanted behaviors."
|
||||
)
|
||||
|
||||
if os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
|
||||
if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file):
|
||||
# If the model is on the GPU, it still works!
|
||||
if is_sagemaker_mp_enabled():
|
||||
if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
|
||||
@@ -2124,7 +2138,7 @@ class Trainer:
|
||||
logger.warning(
|
||||
"Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported."
|
||||
)
|
||||
state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
|
||||
state_dict = torch.load(weights_file, map_location="cpu")
|
||||
# Required for smp to not auto-translate state_dict from hf to smp (is already smp).
|
||||
state_dict["_smp_is_partial"] = False
|
||||
load_result = model.load_state_dict(state_dict, strict=True)
|
||||
@@ -2132,7 +2146,11 @@ class Trainer:
|
||||
del state_dict
|
||||
else:
|
||||
# We load the model state dict on the CPU to avoid an OOM error.
|
||||
state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
|
||||
if self.args.save_safetensors and os.path.isfile(safe_weights_file):
|
||||
state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
|
||||
else:
|
||||
state_dict = torch.load(weights_file, map_location="cpu")
|
||||
|
||||
# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
|
||||
# which takes *args instead of **kwargs
|
||||
load_result = model.load_state_dict(state_dict, False)
|
||||
@@ -2141,15 +2159,18 @@ class Trainer:
|
||||
self._issue_warnings_after_load(load_result)
|
||||
else:
|
||||
# We load the sharded checkpoint
|
||||
load_result = load_sharded_checkpoint(model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled())
|
||||
load_result = load_sharded_checkpoint(
|
||||
model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled(), prefer_safe=self.args.save_safetensors
|
||||
)
|
||||
if not is_sagemaker_mp_enabled():
|
||||
self._issue_warnings_after_load(load_result)
|
||||
|
||||
def _load_best_model(self):
|
||||
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
|
||||
best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
|
||||
best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
|
||||
model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||
if os.path.exists(best_model_path):
|
||||
if os.path.exists(best_model_path) or os.path.exists(best_safe_model_path):
|
||||
if self.deepspeed:
|
||||
if self.model_wrapped is not None:
|
||||
# this removes the pre-hooks from the previous engine
|
||||
@@ -2181,12 +2202,20 @@ class Trainer:
|
||||
else:
|
||||
# If the 'user_content.pt' file does NOT exist, load with the old smp api.
|
||||
# Checkpoint must have been saved with the old smp api.
|
||||
if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
|
||||
state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
|
||||
else:
|
||||
state_dict = torch.load(best_model_path, map_location="cpu")
|
||||
|
||||
state_dict["_smp_is_partial"] = False
|
||||
load_result = model.load_state_dict(state_dict, strict=True)
|
||||
else:
|
||||
# We load the model state dict on the CPU to avoid an OOM error.
|
||||
if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
|
||||
state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
|
||||
else:
|
||||
state_dict = torch.load(best_model_path, map_location="cpu")
|
||||
|
||||
# If the model is on the GPU, it still works!
|
||||
# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
|
||||
# which takes *args instead of **kwargs
|
||||
@@ -2837,17 +2866,24 @@ class Trainer:
|
||||
# Save a trained model and configuration using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
if not isinstance(self.model, PreTrainedModel):
|
||||
if isinstance(unwrap_model(self.model), PreTrainedModel):
|
||||
if state_dict is None:
|
||||
state_dict = self.model.state_dict()
|
||||
unwrap_model(self.model).save_pretrained(output_dir, state_dict=state_dict)
|
||||
|
||||
if isinstance(unwrap_model(self.model), PreTrainedModel):
|
||||
unwrap_model(self.model).save_pretrained(
|
||||
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
|
||||
)
|
||||
else:
|
||||
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
|
||||
if state_dict is None:
|
||||
state_dict = self.model.state_dict()
|
||||
if self.args.save_safetensors:
|
||||
safetensors.torch.save_file(state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME))
|
||||
else:
|
||||
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
||||
else:
|
||||
self.model.save_pretrained(output_dir, state_dict=state_dict)
|
||||
self.model.save_pretrained(
|
||||
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
|
||||
)
|
||||
|
||||
if self.tokenizer is not None:
|
||||
self.tokenizer.save_pretrained(output_dir)
|
||||
|
||||
@@ -3546,7 +3582,7 @@ class Trainer:
|
||||
|
||||
output_dir = self.args.output_dir
|
||||
# To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
|
||||
modeling_files = [CONFIG_NAME, WEIGHTS_NAME]
|
||||
modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME]
|
||||
for modeling_file in modeling_files:
|
||||
if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
|
||||
shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))
|
||||
|
||||
@@ -42,6 +42,7 @@ from .utils import (
|
||||
get_full_repo_name,
|
||||
is_accelerate_available,
|
||||
is_psutil_available,
|
||||
is_safetensors_available,
|
||||
is_sagemaker_dp_enabled,
|
||||
is_sagemaker_mp_enabled,
|
||||
is_torch_available,
|
||||
@@ -261,6 +262,9 @@ class TrainingArguments:
|
||||
save_total_limit (`int`, *optional*):
|
||||
If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in
|
||||
`output_dir`.
|
||||
save_safetensors (`bool`, *optional*, defaults to `False`):
|
||||
Use [safetensors](https://huggingface.co/docs/safetensors) saving and loading for state dicts instead of
|
||||
default `torch.load` and `torch.save`.
|
||||
save_on_each_node (`bool`, *optional*, defaults to `False`):
|
||||
When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on
|
||||
the main one.
|
||||
@@ -720,6 +724,12 @@ class TrainingArguments:
|
||||
)
|
||||
},
|
||||
)
|
||||
save_safetensors: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Use safetensors saving and loading for state dicts instead of default torch.load and torch.save."
|
||||
},
|
||||
)
|
||||
save_on_each_node: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
@@ -1166,6 +1176,17 @@ class TrainingArguments:
|
||||
f"steps, but found {self.save_steps}, which is not a round multiple of {self.eval_steps}."
|
||||
)
|
||||
|
||||
safetensors_available = is_safetensors_available()
|
||||
if self.save_safetensors and not safetensors_available:
|
||||
raise ValueError(f"--save_safetensors={self.save_safetensors} requires safetensors to be installed!")
|
||||
if not self.save_safetensors and safetensors_available:
|
||||
logger.info(
|
||||
f"Found safetensors installation, but --save_safetensors={self.save_safetensors}. "
|
||||
f"Safetensors should be a preferred weights saving format due to security and performance reasons. "
|
||||
f"If your model cannot be saved by safetensors please feel free to open an issue at "
|
||||
f"https://github.com/huggingface/safetensors!"
|
||||
)
|
||||
|
||||
if self.load_best_model_at_end and self.metric_for_best_model is None:
|
||||
self.metric_for_best_model = "loss"
|
||||
if self.greater_is_better is None and self.metric_for_best_model is not None:
|
||||
|
||||
@@ -25,6 +25,7 @@ import sys
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
from itertools import product
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
@@ -54,6 +55,7 @@ from transformers.testing_utils import (
|
||||
require_intel_extension_for_pytorch,
|
||||
require_optuna,
|
||||
require_ray,
|
||||
require_safetensors,
|
||||
require_sentencepiece,
|
||||
require_sigopt,
|
||||
require_tokenizers,
|
||||
@@ -73,10 +75,13 @@ from transformers.testing_utils import (
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
from transformers.training_args import OptimizerNames
|
||||
from transformers.utils import (
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
SAFE_WEIGHTS_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
WEIGHTS_NAME,
|
||||
is_apex_available,
|
||||
is_bitsandbytes_available,
|
||||
is_safetensors_available,
|
||||
is_torchdistx_available,
|
||||
)
|
||||
from transformers.utils.hp_naming import TrialShortNamer
|
||||
@@ -102,6 +107,9 @@ if is_torch_available():
|
||||
)
|
||||
from transformers.modeling_utils import unwrap_model
|
||||
|
||||
if is_safetensors_available():
|
||||
import safetensors.torch
|
||||
|
||||
|
||||
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
|
||||
|
||||
@@ -345,8 +353,9 @@ if is_torch_available():
|
||||
|
||||
|
||||
class TrainerIntegrationCommon:
|
||||
def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True):
|
||||
file_list = [WEIGHTS_NAME, "training_args.bin", "optimizer.pt", "scheduler.pt", "trainer_state.json"]
|
||||
def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True, safe_weights=False):
|
||||
weights_file = WEIGHTS_NAME if not safe_weights else SAFE_WEIGHTS_NAME
|
||||
file_list = [weights_file, "training_args.bin", "optimizer.pt", "scheduler.pt", "trainer_state.json"]
|
||||
if is_pretrained:
|
||||
file_list.append("config.json")
|
||||
for step in range(freq, total, freq):
|
||||
@@ -356,7 +365,7 @@ class TrainerIntegrationCommon:
|
||||
self.assertTrue(os.path.isfile(os.path.join(checkpoint, filename)))
|
||||
|
||||
def check_best_model_has_been_loaded(
|
||||
self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True
|
||||
self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True, safe_weights=False
|
||||
):
|
||||
checkpoint = os.path.join(output_dir, f"checkpoint-{(total // freq) * freq}")
|
||||
log_history = TrainerState.load_from_json(os.path.join(checkpoint, "trainer_state.json")).log_history
|
||||
@@ -370,7 +379,10 @@ class TrainerIntegrationCommon:
|
||||
best_model.to(trainer.args.device)
|
||||
else:
|
||||
best_model = RegressionModel()
|
||||
if not safe_weights:
|
||||
state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME))
|
||||
else:
|
||||
state_dict = safetensors.torch.load_file(os.path.join(checkpoint, SAFE_WEIGHTS_NAME))
|
||||
best_model.load_state_dict(state_dict)
|
||||
best_model.to(trainer.args.device)
|
||||
self.assertTrue(torch.allclose(best_model.a, trainer.model.a))
|
||||
@@ -394,24 +406,43 @@ class TrainerIntegrationCommon:
|
||||
_ = log1.pop(key, None)
|
||||
self.assertEqual(log, log1)
|
||||
|
||||
def convert_to_sharded_checkpoint(self, folder):
|
||||
def convert_to_sharded_checkpoint(self, folder, save_safe=False, load_safe=False):
|
||||
# Converts a checkpoint of a regression model to a sharded checkpoint.
|
||||
state_dict = torch.load(os.path.join(folder, WEIGHTS_NAME))
|
||||
os.remove(os.path.join(folder, WEIGHTS_NAME))
|
||||
if load_safe:
|
||||
loader = safetensors.torch.load_file
|
||||
weights_file = os.path.join(folder, SAFE_WEIGHTS_NAME)
|
||||
else:
|
||||
loader = torch.load
|
||||
weights_file = os.path.join(folder, WEIGHTS_NAME)
|
||||
|
||||
if save_safe:
|
||||
extension = "safetensors"
|
||||
saver = safetensors.torch.save_file
|
||||
index_file = os.path.join(folder, SAFE_WEIGHTS_INDEX_NAME)
|
||||
shard_name = SAFE_WEIGHTS_NAME
|
||||
else:
|
||||
extension = "bin"
|
||||
saver = torch.save
|
||||
index_file = os.path.join(folder, WEIGHTS_INDEX_NAME)
|
||||
shard_name = WEIGHTS_NAME
|
||||
|
||||
state_dict = loader(weights_file)
|
||||
|
||||
os.remove(weights_file)
|
||||
keys = list(state_dict.keys())
|
||||
|
||||
shard_files = [
|
||||
WEIGHTS_NAME.replace(".bin", f"-{idx+1:05d}-of-{len(keys):05d}.bin") for idx in range(len(keys))
|
||||
shard_name.replace(f".{extension}", f"-{idx+1:05d}-of-{len(keys):05d}.{extension}")
|
||||
for idx in range(len(keys))
|
||||
]
|
||||
index = {"metadata": {}, "weight_map": {key: shard_files[i] for i, key in enumerate(keys)}}
|
||||
|
||||
save_index_file = os.path.join(folder, WEIGHTS_INDEX_NAME)
|
||||
with open(save_index_file, "w", encoding="utf-8") as f:
|
||||
with open(index_file, "w", encoding="utf-8") as f:
|
||||
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
||||
f.write(content)
|
||||
|
||||
for param_name, shard_file in zip(keys, shard_files):
|
||||
torch.save({param_name: state_dict[param_name]}, os.path.join(folder, shard_file))
|
||||
saver({param_name: state_dict[param_name]}, os.path.join(folder, shard_file))
|
||||
|
||||
|
||||
@require_torch
|
||||
@@ -1132,6 +1163,26 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
trainer.train()
|
||||
self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False)
|
||||
|
||||
@require_safetensors
|
||||
def test_safe_checkpoints(self):
|
||||
for save_safetensors in [True, False]:
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5, save_safetensors=save_safetensors)
|
||||
trainer.train()
|
||||
self.check_saved_checkpoints(
|
||||
tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), safe_weights=save_safetensors
|
||||
)
|
||||
|
||||
# With a regular model that is not a PreTrainedModel
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
trainer = get_regression_trainer(
|
||||
output_dir=tmpdir, save_steps=5, pretrained=False, save_safetensors=save_safetensors
|
||||
)
|
||||
trainer.train()
|
||||
self.check_saved_checkpoints(
|
||||
tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False, safe_weights=save_safetensors
|
||||
)
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_run_seq2seq_double_train_wrap_once(self):
|
||||
# test that we don't wrap the model more than once
|
||||
@@ -1373,6 +1424,42 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertEqual(b, b1)
|
||||
self.check_trainer_state_are_the_same(state, state1)
|
||||
|
||||
@require_safetensors
|
||||
@require_torch_up_to_2_gpus
|
||||
def test_resume_training_with_safe_checkpoint(self):
|
||||
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
|
||||
# save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
|
||||
# won't be the same since the training dataloader is shuffled).
|
||||
|
||||
for initial_safe in [False, True]:
|
||||
for loaded_safe in [False, True]:
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
trainer = get_regression_trainer(
|
||||
output_dir=tmpdir,
|
||||
train_len=128,
|
||||
save_steps=5,
|
||||
learning_rate=0.1,
|
||||
save_safetensors=initial_safe,
|
||||
)
|
||||
trainer.train()
|
||||
(a, b) = trainer.model.a.item(), trainer.model.b.item()
|
||||
state = dataclasses.asdict(trainer.state)
|
||||
|
||||
checkpoint = os.path.join(tmpdir, "checkpoint-5")
|
||||
self.convert_to_sharded_checkpoint(checkpoint, load_safe=initial_safe, save_safe=loaded_safe)
|
||||
|
||||
# Reinitialize trainer
|
||||
trainer = get_regression_trainer(
|
||||
output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, save_safetensors=loaded_safe
|
||||
)
|
||||
|
||||
trainer.train(resume_from_checkpoint=checkpoint)
|
||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||
state1 = dataclasses.asdict(trainer.state)
|
||||
self.assertEqual(a, a1)
|
||||
self.assertEqual(b, b1)
|
||||
self.check_trainer_state_are_the_same(state, state1)
|
||||
|
||||
@require_torch_up_to_2_gpus
|
||||
def test_resume_training_with_gradient_accumulation(self):
|
||||
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
|
||||
@@ -1522,6 +1609,30 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.check_saved_checkpoints(tmpdir, 5, total, is_pretrained=False)
|
||||
self.check_best_model_has_been_loaded(tmpdir, 5, total, trainer, "eval_loss", is_pretrained=False)
|
||||
|
||||
@require_safetensors
|
||||
def test_load_best_model_from_safetensors(self):
|
||||
total = int(self.n_epochs * 64 / self.batch_size)
|
||||
for save_safetensors, pretrained in product([False, True], [False, True]):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
trainer = get_regression_trainer(
|
||||
a=1.5,
|
||||
b=2.5,
|
||||
output_dir=tmpdir,
|
||||
learning_rate=0.1,
|
||||
eval_steps=5,
|
||||
evaluation_strategy="steps",
|
||||
save_steps=5,
|
||||
load_best_model_at_end=True,
|
||||
save_safetensors=save_safetensors,
|
||||
pretrained=pretrained,
|
||||
)
|
||||
self.assertFalse(trainer.args.greater_is_better)
|
||||
trainer.train()
|
||||
self.check_saved_checkpoints(tmpdir, 5, total, is_pretrained=pretrained, safe_weights=save_safetensors)
|
||||
self.check_best_model_has_been_loaded(
|
||||
tmpdir, 5, total, trainer, "eval_loss", is_pretrained=pretrained, safe_weights=save_safetensors
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_trainer_eval_mrpc(self):
|
||||
MODEL_ID = "bert-base-cased-finetuned-mrpc"
|
||||
|
||||
Reference in New Issue
Block a user