From 871598be552c38537bc047a409b4a6840ba1c1e4 Mon Sep 17 00:00:00 2001 From: Viktor Scherbakov Date: Tue, 4 Apr 2023 16:05:04 +0300 Subject: [PATCH] Implemented safetensors checkpoints save/load for Trainer (#22498) * implemented safetensors save/load * remove duplicated file * added tests * more tests * style fix * fix tf tests * change to list comprehension Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * review fixes + safe load for sharded checkpoint * style fix * remove rogue import * remove partial to avoid undefined exception * use naming alias instead of safetensors.torch * fix safe sharding in tests * grammar Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * update docs Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * update docs Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * minor corrections * style --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/modeling_utils.py | 39 +++++++-- src/transformers/trainer.py | 74 +++++++++++----- src/transformers/training_args.py | 21 +++++ tests/trainer/test_trainer.py | 133 ++++++++++++++++++++++++++--- 4 files changed, 231 insertions(+), 36 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 9a6c29c27b..27faa25278 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -336,7 +336,7 @@ def shard_checkpoint( return shards, index -def load_sharded_checkpoint(model, folder, strict=True): +def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True): """ This is the same as [`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict) @@ -350,6 +350,9 @@ def load_sharded_checkpoint(model, folder, strict=True): folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint. strict (`bool`, *optional`, defaults to `True`): Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint. + prefer_safe (`bool`, *optional*, defaults to `False`) + If both safetensors and PyTorch save files are present in checkpoint and `prefer_safe` is True, the + safetensors files will be loaded. Otherwise, PyTorch files are always loaded when possible. Returns: `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields @@ -358,10 +361,32 @@ def load_sharded_checkpoint(model, folder, strict=True): """ # Load the index index_file = os.path.join(folder, WEIGHTS_INDEX_NAME) - if not os.path.isfile(index_file): - raise ValueError(f"Can't find a checkpoint index ({WEIGHTS_INDEX_NAME}) in {folder}.") + safe_index_file = os.path.join(folder, SAFE_WEIGHTS_INDEX_NAME) - with open(index_file, "r", encoding="utf-8") as f: + index_present = os.path.isfile(index_file) + safe_index_present = os.path.isfile(safe_index_file) + + if not index_present and not (safe_index_present and is_safetensors_available()): + filenames = ( + (WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME) if is_safetensors_available() else (WEIGHTS_INDEX_NAME,) + ) + raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {folder}.") + + load_safe = False + if safe_index_present: + if prefer_safe: + if is_safetensors_available(): + load_safe = True # load safe due to preference + else: + logger.warning( + f"Cannot load sharded checkpoint at {folder} safely since safetensors is not installed!" + ) + elif not index_present: + load_safe = True # load safe since we have no other choice + + load_index = safe_index_file if load_safe else index_file + + with open(load_index, "r", encoding="utf-8") as f: index = json.load(f) shard_files = list(set(index["weight_map"].values())) @@ -381,11 +406,13 @@ def load_sharded_checkpoint(model, folder, strict=True): error_message += f"\nMissing key(s): {str_unexpected_keys}." raise RuntimeError(error_message) + loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu") + for shard_file in shard_files: - state_dict = torch.load(os.path.join(folder, shard_file), map_location="cpu") + state_dict = loader(os.path.join(folder, shard_file)) model.load_state_dict(state_dict, strict=False) - # Make sure memory is fred before we load the next state dict. + # Make sure memory is freed before we load the next state dict. del state_dict gc.collect() diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index ab79e94a68..2eb081af7c 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -135,6 +135,8 @@ from .trainer_utils import ( from .training_args import OptimizerNames, ParallelMode, TrainingArguments from .utils import ( CONFIG_NAME, + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME, can_return_loss, @@ -145,6 +147,7 @@ from .utils import ( is_datasets_available, is_in_notebook, is_ipex_available, + is_safetensors_available, is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, is_torch_compile_available, @@ -198,6 +201,10 @@ else: IS_SAGEMAKER_MP_POST_1_10 = False +if is_safetensors_available(): + import safetensors.torch + + skip_first_batches = None if is_accelerate_available(): from accelerate import __version__ as accelerate_version @@ -2091,15 +2098,22 @@ class Trainer: if model is None: model = self.model - if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)) and not os.path.isfile( - os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME) + config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME) + + weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME) + weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME) + safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME) + safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME) + + if not any( + [os.path.isfile(f) for f in [weights_file, safe_weights_file, weights_index_file, safe_weights_index_file]] ): raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}") logger.info(f"Loading model from {resume_from_checkpoint}.") - if os.path.isfile(os.path.join(resume_from_checkpoint, CONFIG_NAME)): - config = PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME)) + if os.path.isfile(config_file): + config = PretrainedConfig.from_json_file(config_file) checkpoint_version = config.transformers_version if checkpoint_version is not None and checkpoint_version != __version__: logger.warning( @@ -2108,7 +2122,7 @@ class Trainer: "yield to errors or unwanted behaviors." ) - if os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)): + if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file): # If the model is on the GPU, it still works! if is_sagemaker_mp_enabled(): if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")): @@ -2124,7 +2138,7 @@ class Trainer: logger.warning( "Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported." ) - state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu") + state_dict = torch.load(weights_file, map_location="cpu") # Required for smp to not auto-translate state_dict from hf to smp (is already smp). state_dict["_smp_is_partial"] = False load_result = model.load_state_dict(state_dict, strict=True) @@ -2132,7 +2146,11 @@ class Trainer: del state_dict else: # We load the model state dict on the CPU to avoid an OOM error. - state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu") + if self.args.save_safetensors and os.path.isfile(safe_weights_file): + state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu") + else: + state_dict = torch.load(weights_file, map_location="cpu") + # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963 # which takes *args instead of **kwargs load_result = model.load_state_dict(state_dict, False) @@ -2141,15 +2159,18 @@ class Trainer: self._issue_warnings_after_load(load_result) else: # We load the sharded checkpoint - load_result = load_sharded_checkpoint(model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled()) + load_result = load_sharded_checkpoint( + model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled(), prefer_safe=self.args.save_safetensors + ) if not is_sagemaker_mp_enabled(): self._issue_warnings_after_load(load_result) def _load_best_model(self): logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).") best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME) + best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME) model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model - if os.path.exists(best_model_path): + if os.path.exists(best_model_path) or os.path.exists(best_safe_model_path): if self.deepspeed: if self.model_wrapped is not None: # this removes the pre-hooks from the previous engine @@ -2181,12 +2202,20 @@ class Trainer: else: # If the 'user_content.pt' file does NOT exist, load with the old smp api. # Checkpoint must have been saved with the old smp api. - state_dict = torch.load(best_model_path, map_location="cpu") + if self.args.save_safetensors and os.path.isfile(best_safe_model_path): + state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu") + else: + state_dict = torch.load(best_model_path, map_location="cpu") + state_dict["_smp_is_partial"] = False load_result = model.load_state_dict(state_dict, strict=True) else: # We load the model state dict on the CPU to avoid an OOM error. - state_dict = torch.load(best_model_path, map_location="cpu") + if self.args.save_safetensors and os.path.isfile(best_safe_model_path): + state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu") + else: + state_dict = torch.load(best_model_path, map_location="cpu") + # If the model is on the GPU, it still works! # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963 # which takes *args instead of **kwargs @@ -2837,17 +2866,24 @@ class Trainer: # Save a trained model and configuration using `save_pretrained()`. # They can then be reloaded using `from_pretrained()` if not isinstance(self.model, PreTrainedModel): + if state_dict is None: + state_dict = self.model.state_dict() + if isinstance(unwrap_model(self.model), PreTrainedModel): - if state_dict is None: - state_dict = self.model.state_dict() - unwrap_model(self.model).save_pretrained(output_dir, state_dict=state_dict) + unwrap_model(self.model).save_pretrained( + output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors + ) else: logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") - if state_dict is None: - state_dict = self.model.state_dict() - torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) + if self.args.save_safetensors: + safetensors.torch.save_file(state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME)) + else: + torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) else: - self.model.save_pretrained(output_dir, state_dict=state_dict) + self.model.save_pretrained( + output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors + ) + if self.tokenizer is not None: self.tokenizer.save_pretrained(output_dir) @@ -3546,7 +3582,7 @@ class Trainer: output_dir = self.args.output_dir # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder - modeling_files = [CONFIG_NAME, WEIGHTS_NAME] + modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME] for modeling_file in modeling_files: if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)): shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file)) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 2a3c326732..28387885de 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -42,6 +42,7 @@ from .utils import ( get_full_repo_name, is_accelerate_available, is_psutil_available, + is_safetensors_available, is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, is_torch_available, @@ -261,6 +262,9 @@ class TrainingArguments: save_total_limit (`int`, *optional*): If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in `output_dir`. + save_safetensors (`bool`, *optional*, defaults to `False`): + Use [safetensors](https://huggingface.co/docs/safetensors) saving and loading for state dicts instead of + default `torch.load` and `torch.save`. save_on_each_node (`bool`, *optional*, defaults to `False`): When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on the main one. @@ -720,6 +724,12 @@ class TrainingArguments: ) }, ) + save_safetensors: Optional[bool] = field( + default=False, + metadata={ + "help": "Use safetensors saving and loading for state dicts instead of default torch.load and torch.save." + }, + ) save_on_each_node: bool = field( default=False, metadata={ @@ -1166,6 +1176,17 @@ class TrainingArguments: f"steps, but found {self.save_steps}, which is not a round multiple of {self.eval_steps}." ) + safetensors_available = is_safetensors_available() + if self.save_safetensors and not safetensors_available: + raise ValueError(f"--save_safetensors={self.save_safetensors} requires safetensors to be installed!") + if not self.save_safetensors and safetensors_available: + logger.info( + f"Found safetensors installation, but --save_safetensors={self.save_safetensors}. " + f"Safetensors should be a preferred weights saving format due to security and performance reasons. " + f"If your model cannot be saved by safetensors please feel free to open an issue at " + f"https://github.com/huggingface/safetensors!" + ) + if self.load_best_model_at_end and self.metric_for_best_model is None: self.metric_for_best_model = "loss" if self.greater_is_better is None and self.metric_for_best_model is not None: diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 310842713b..78b6afeacd 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -25,6 +25,7 @@ import sys import tempfile import time import unittest +from itertools import product from pathlib import Path from unittest.mock import Mock, patch @@ -54,6 +55,7 @@ from transformers.testing_utils import ( require_intel_extension_for_pytorch, require_optuna, require_ray, + require_safetensors, require_sentencepiece, require_sigopt, require_tokenizers, @@ -73,10 +75,13 @@ from transformers.testing_utils import ( from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR from transformers.training_args import OptimizerNames from transformers.utils import ( + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME, is_apex_available, is_bitsandbytes_available, + is_safetensors_available, is_torchdistx_available, ) from transformers.utils.hp_naming import TrialShortNamer @@ -102,6 +107,9 @@ if is_torch_available(): ) from transformers.modeling_utils import unwrap_model + if is_safetensors_available(): + import safetensors.torch + PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt" @@ -345,8 +353,9 @@ if is_torch_available(): class TrainerIntegrationCommon: - def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True): - file_list = [WEIGHTS_NAME, "training_args.bin", "optimizer.pt", "scheduler.pt", "trainer_state.json"] + def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True, safe_weights=False): + weights_file = WEIGHTS_NAME if not safe_weights else SAFE_WEIGHTS_NAME + file_list = [weights_file, "training_args.bin", "optimizer.pt", "scheduler.pt", "trainer_state.json"] if is_pretrained: file_list.append("config.json") for step in range(freq, total, freq): @@ -356,7 +365,7 @@ class TrainerIntegrationCommon: self.assertTrue(os.path.isfile(os.path.join(checkpoint, filename))) def check_best_model_has_been_loaded( - self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True + self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True, safe_weights=False ): checkpoint = os.path.join(output_dir, f"checkpoint-{(total // freq) * freq}") log_history = TrainerState.load_from_json(os.path.join(checkpoint, "trainer_state.json")).log_history @@ -370,7 +379,10 @@ class TrainerIntegrationCommon: best_model.to(trainer.args.device) else: best_model = RegressionModel() - state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME)) + if not safe_weights: + state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME)) + else: + state_dict = safetensors.torch.load_file(os.path.join(checkpoint, SAFE_WEIGHTS_NAME)) best_model.load_state_dict(state_dict) best_model.to(trainer.args.device) self.assertTrue(torch.allclose(best_model.a, trainer.model.a)) @@ -394,24 +406,43 @@ class TrainerIntegrationCommon: _ = log1.pop(key, None) self.assertEqual(log, log1) - def convert_to_sharded_checkpoint(self, folder): + def convert_to_sharded_checkpoint(self, folder, save_safe=False, load_safe=False): # Converts a checkpoint of a regression model to a sharded checkpoint. - state_dict = torch.load(os.path.join(folder, WEIGHTS_NAME)) - os.remove(os.path.join(folder, WEIGHTS_NAME)) + if load_safe: + loader = safetensors.torch.load_file + weights_file = os.path.join(folder, SAFE_WEIGHTS_NAME) + else: + loader = torch.load + weights_file = os.path.join(folder, WEIGHTS_NAME) + + if save_safe: + extension = "safetensors" + saver = safetensors.torch.save_file + index_file = os.path.join(folder, SAFE_WEIGHTS_INDEX_NAME) + shard_name = SAFE_WEIGHTS_NAME + else: + extension = "bin" + saver = torch.save + index_file = os.path.join(folder, WEIGHTS_INDEX_NAME) + shard_name = WEIGHTS_NAME + + state_dict = loader(weights_file) + + os.remove(weights_file) keys = list(state_dict.keys()) shard_files = [ - WEIGHTS_NAME.replace(".bin", f"-{idx+1:05d}-of-{len(keys):05d}.bin") for idx in range(len(keys)) + shard_name.replace(f".{extension}", f"-{idx+1:05d}-of-{len(keys):05d}.{extension}") + for idx in range(len(keys)) ] index = {"metadata": {}, "weight_map": {key: shard_files[i] for i, key in enumerate(keys)}} - save_index_file = os.path.join(folder, WEIGHTS_INDEX_NAME) - with open(save_index_file, "w", encoding="utf-8") as f: + with open(index_file, "w", encoding="utf-8") as f: content = json.dumps(index, indent=2, sort_keys=True) + "\n" f.write(content) for param_name, shard_file in zip(keys, shard_files): - torch.save({param_name: state_dict[param_name]}, os.path.join(folder, shard_file)) + saver({param_name: state_dict[param_name]}, os.path.join(folder, shard_file)) @require_torch @@ -1132,6 +1163,26 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): trainer.train() self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False) + @require_safetensors + def test_safe_checkpoints(self): + for save_safetensors in [True, False]: + with tempfile.TemporaryDirectory() as tmpdir: + trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5, save_safetensors=save_safetensors) + trainer.train() + self.check_saved_checkpoints( + tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), safe_weights=save_safetensors + ) + + # With a regular model that is not a PreTrainedModel + with tempfile.TemporaryDirectory() as tmpdir: + trainer = get_regression_trainer( + output_dir=tmpdir, save_steps=5, pretrained=False, save_safetensors=save_safetensors + ) + trainer.train() + self.check_saved_checkpoints( + tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False, safe_weights=save_safetensors + ) + @require_torch_multi_gpu def test_run_seq2seq_double_train_wrap_once(self): # test that we don't wrap the model more than once @@ -1373,6 +1424,42 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): self.assertEqual(b, b1) self.check_trainer_state_are_the_same(state, state1) + @require_safetensors + @require_torch_up_to_2_gpus + def test_resume_training_with_safe_checkpoint(self): + # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of + # save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model + # won't be the same since the training dataloader is shuffled). + + for initial_safe in [False, True]: + for loaded_safe in [False, True]: + with tempfile.TemporaryDirectory() as tmpdir: + trainer = get_regression_trainer( + output_dir=tmpdir, + train_len=128, + save_steps=5, + learning_rate=0.1, + save_safetensors=initial_safe, + ) + trainer.train() + (a, b) = trainer.model.a.item(), trainer.model.b.item() + state = dataclasses.asdict(trainer.state) + + checkpoint = os.path.join(tmpdir, "checkpoint-5") + self.convert_to_sharded_checkpoint(checkpoint, load_safe=initial_safe, save_safe=loaded_safe) + + # Reinitialize trainer + trainer = get_regression_trainer( + output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, save_safetensors=loaded_safe + ) + + trainer.train(resume_from_checkpoint=checkpoint) + (a1, b1) = trainer.model.a.item(), trainer.model.b.item() + state1 = dataclasses.asdict(trainer.state) + self.assertEqual(a, a1) + self.assertEqual(b, b1) + self.check_trainer_state_are_the_same(state, state1) + @require_torch_up_to_2_gpus def test_resume_training_with_gradient_accumulation(self): # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of @@ -1522,6 +1609,30 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): self.check_saved_checkpoints(tmpdir, 5, total, is_pretrained=False) self.check_best_model_has_been_loaded(tmpdir, 5, total, trainer, "eval_loss", is_pretrained=False) + @require_safetensors + def test_load_best_model_from_safetensors(self): + total = int(self.n_epochs * 64 / self.batch_size) + for save_safetensors, pretrained in product([False, True], [False, True]): + with tempfile.TemporaryDirectory() as tmpdir: + trainer = get_regression_trainer( + a=1.5, + b=2.5, + output_dir=tmpdir, + learning_rate=0.1, + eval_steps=5, + evaluation_strategy="steps", + save_steps=5, + load_best_model_at_end=True, + save_safetensors=save_safetensors, + pretrained=pretrained, + ) + self.assertFalse(trainer.args.greater_is_better) + trainer.train() + self.check_saved_checkpoints(tmpdir, 5, total, is_pretrained=pretrained, safe_weights=save_safetensors) + self.check_best_model_has_been_loaded( + tmpdir, 5, total, trainer, "eval_loss", is_pretrained=pretrained, safe_weights=save_safetensors + ) + @slow def test_trainer_eval_mrpc(self): MODEL_ID = "bert-base-cased-finetuned-mrpc"