Implemented safetensors checkpoints save/load for Trainer (#22498)

* implemented safetensors save/load

* remove duplicated file

* added tests

* more tests

* style fix

* fix tf tests

* change to list comprehension

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* review fixes + safe load for sharded checkpoint

* style fix

* remove rogue import

* remove partial to avoid undefined exception

* use naming alias instead of safetensors.torch

* fix safe sharding in tests

* grammar

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* update docs

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* update docs

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* minor corrections

* style

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Viktor Scherbakov
2023-04-04 16:05:04 +03:00
committed by GitHub
parent 00b5887b94
commit 871598be55
4 changed files with 231 additions and 36 deletions

View File

@@ -25,6 +25,7 @@ import sys
import tempfile
import time
import unittest
from itertools import product
from pathlib import Path
from unittest.mock import Mock, patch
@@ -54,6 +55,7 @@ from transformers.testing_utils import (
require_intel_extension_for_pytorch,
require_optuna,
require_ray,
require_safetensors,
require_sentencepiece,
require_sigopt,
require_tokenizers,
@@ -73,10 +75,13 @@ from transformers.testing_utils import (
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.training_args import OptimizerNames
from transformers.utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
is_apex_available,
is_bitsandbytes_available,
is_safetensors_available,
is_torchdistx_available,
)
from transformers.utils.hp_naming import TrialShortNamer
@@ -102,6 +107,9 @@ if is_torch_available():
)
from transformers.modeling_utils import unwrap_model
if is_safetensors_available():
import safetensors.torch
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
@@ -345,8 +353,9 @@ if is_torch_available():
class TrainerIntegrationCommon:
def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True):
file_list = [WEIGHTS_NAME, "training_args.bin", "optimizer.pt", "scheduler.pt", "trainer_state.json"]
def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True, safe_weights=False):
weights_file = WEIGHTS_NAME if not safe_weights else SAFE_WEIGHTS_NAME
file_list = [weights_file, "training_args.bin", "optimizer.pt", "scheduler.pt", "trainer_state.json"]
if is_pretrained:
file_list.append("config.json")
for step in range(freq, total, freq):
@@ -356,7 +365,7 @@ class TrainerIntegrationCommon:
self.assertTrue(os.path.isfile(os.path.join(checkpoint, filename)))
def check_best_model_has_been_loaded(
self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True
self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True, safe_weights=False
):
checkpoint = os.path.join(output_dir, f"checkpoint-{(total // freq) * freq}")
log_history = TrainerState.load_from_json(os.path.join(checkpoint, "trainer_state.json")).log_history
@@ -370,7 +379,10 @@ class TrainerIntegrationCommon:
best_model.to(trainer.args.device)
else:
best_model = RegressionModel()
state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME))
if not safe_weights:
state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME))
else:
state_dict = safetensors.torch.load_file(os.path.join(checkpoint, SAFE_WEIGHTS_NAME))
best_model.load_state_dict(state_dict)
best_model.to(trainer.args.device)
self.assertTrue(torch.allclose(best_model.a, trainer.model.a))
@@ -394,24 +406,43 @@ class TrainerIntegrationCommon:
_ = log1.pop(key, None)
self.assertEqual(log, log1)
def convert_to_sharded_checkpoint(self, folder):
def convert_to_sharded_checkpoint(self, folder, save_safe=False, load_safe=False):
# Converts a checkpoint of a regression model to a sharded checkpoint.
state_dict = torch.load(os.path.join(folder, WEIGHTS_NAME))
os.remove(os.path.join(folder, WEIGHTS_NAME))
if load_safe:
loader = safetensors.torch.load_file
weights_file = os.path.join(folder, SAFE_WEIGHTS_NAME)
else:
loader = torch.load
weights_file = os.path.join(folder, WEIGHTS_NAME)
if save_safe:
extension = "safetensors"
saver = safetensors.torch.save_file
index_file = os.path.join(folder, SAFE_WEIGHTS_INDEX_NAME)
shard_name = SAFE_WEIGHTS_NAME
else:
extension = "bin"
saver = torch.save
index_file = os.path.join(folder, WEIGHTS_INDEX_NAME)
shard_name = WEIGHTS_NAME
state_dict = loader(weights_file)
os.remove(weights_file)
keys = list(state_dict.keys())
shard_files = [
WEIGHTS_NAME.replace(".bin", f"-{idx+1:05d}-of-{len(keys):05d}.bin") for idx in range(len(keys))
shard_name.replace(f".{extension}", f"-{idx+1:05d}-of-{len(keys):05d}.{extension}")
for idx in range(len(keys))
]
index = {"metadata": {}, "weight_map": {key: shard_files[i] for i, key in enumerate(keys)}}
save_index_file = os.path.join(folder, WEIGHTS_INDEX_NAME)
with open(save_index_file, "w", encoding="utf-8") as f:
with open(index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
for param_name, shard_file in zip(keys, shard_files):
torch.save({param_name: state_dict[param_name]}, os.path.join(folder, shard_file))
saver({param_name: state_dict[param_name]}, os.path.join(folder, shard_file))
@require_torch
@@ -1132,6 +1163,26 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
trainer.train()
self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False)
@require_safetensors
def test_safe_checkpoints(self):
for save_safetensors in [True, False]:
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5, save_safetensors=save_safetensors)
trainer.train()
self.check_saved_checkpoints(
tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), safe_weights=save_safetensors
)
# With a regular model that is not a PreTrainedModel
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(
output_dir=tmpdir, save_steps=5, pretrained=False, save_safetensors=save_safetensors
)
trainer.train()
self.check_saved_checkpoints(
tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False, safe_weights=save_safetensors
)
@require_torch_multi_gpu
def test_run_seq2seq_double_train_wrap_once(self):
# test that we don't wrap the model more than once
@@ -1373,6 +1424,42 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertEqual(b, b1)
self.check_trainer_state_are_the_same(state, state1)
@require_safetensors
@require_torch_up_to_2_gpus
def test_resume_training_with_safe_checkpoint(self):
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
# save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
# won't be the same since the training dataloader is shuffled).
for initial_safe in [False, True]:
for loaded_safe in [False, True]:
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(
output_dir=tmpdir,
train_len=128,
save_steps=5,
learning_rate=0.1,
save_safetensors=initial_safe,
)
trainer.train()
(a, b) = trainer.model.a.item(), trainer.model.b.item()
state = dataclasses.asdict(trainer.state)
checkpoint = os.path.join(tmpdir, "checkpoint-5")
self.convert_to_sharded_checkpoint(checkpoint, load_safe=initial_safe, save_safe=loaded_safe)
# Reinitialize trainer
trainer = get_regression_trainer(
output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, save_safetensors=loaded_safe
)
trainer.train(resume_from_checkpoint=checkpoint)
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
state1 = dataclasses.asdict(trainer.state)
self.assertEqual(a, a1)
self.assertEqual(b, b1)
self.check_trainer_state_are_the_same(state, state1)
@require_torch_up_to_2_gpus
def test_resume_training_with_gradient_accumulation(self):
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
@@ -1522,6 +1609,30 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.check_saved_checkpoints(tmpdir, 5, total, is_pretrained=False)
self.check_best_model_has_been_loaded(tmpdir, 5, total, trainer, "eval_loss", is_pretrained=False)
@require_safetensors
def test_load_best_model_from_safetensors(self):
total = int(self.n_epochs * 64 / self.batch_size)
for save_safetensors, pretrained in product([False, True], [False, True]):
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(
a=1.5,
b=2.5,
output_dir=tmpdir,
learning_rate=0.1,
eval_steps=5,
evaluation_strategy="steps",
save_steps=5,
load_best_model_at_end=True,
save_safetensors=save_safetensors,
pretrained=pretrained,
)
self.assertFalse(trainer.args.greater_is_better)
trainer.train()
self.check_saved_checkpoints(tmpdir, 5, total, is_pretrained=pretrained, safe_weights=save_safetensors)
self.check_best_model_has_been_loaded(
tmpdir, 5, total, trainer, "eval_loss", is_pretrained=pretrained, safe_weights=save_safetensors
)
@slow
def test_trainer_eval_mrpc(self):
MODEL_ID = "bert-base-cased-finetuned-mrpc"