Make Trainer compatible with sharded checkpoints (#17053)
* Make Trainer compatible with sharded checkpoints * Add doc
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
|
||||
import dataclasses
|
||||
import gc
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
@@ -65,7 +66,7 @@ from transformers.testing_utils import (
|
||||
)
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
from transformers.training_args import OptimizerNames
|
||||
from transformers.utils import WEIGHTS_NAME, is_apex_available, is_bitsandbytes_available
|
||||
from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME, is_apex_available, is_bitsandbytes_available
|
||||
from transformers.utils.hp_naming import TrialShortNamer
|
||||
|
||||
|
||||
@@ -376,6 +377,25 @@ class TrainerIntegrationCommon:
|
||||
_ = log1.pop(key, None)
|
||||
self.assertEqual(log, log1)
|
||||
|
||||
def convert_to_sharded_checkpoint(self, folder):
    """
    Turn the regression-model checkpoint stored in `folder` into a sharded checkpoint.

    The single weights file is loaded and deleted, then replaced by one shard file per entry of the state
    dict, together with an index file mapping each weight name to the shard that stores it.

    Args:
        folder: Path of the checkpoint directory containing the `WEIGHTS_NAME` file to convert.
    """
    weights_path = os.path.join(folder, WEIGHTS_NAME)
    state_dict = torch.load(weights_path)
    # Drop the single-file checkpoint so that only the sharded layout is left in the folder.
    os.remove(weights_path)

    param_names = list(state_dict.keys())
    num_shards = len(param_names)

    # One shard per weight, numbered from 1, e.g. `<stem>-00001-of-00002.bin`.
    shard_names = []
    for shard_number in range(1, num_shards + 1):
        shard_suffix = f"-{shard_number:05d}-of-{num_shards:05d}.bin"
        shard_names.append(WEIGHTS_NAME.replace(".bin", shard_suffix))

    weight_map = dict(zip(param_names, shard_names))
    index = {"metadata": {}, "weight_map": weight_map}

    index_path = os.path.join(folder, WEIGHTS_INDEX_NAME)
    with open(index_path, "w", encoding="utf-8") as index_file:
        index_file.write(json.dumps(index, indent=2, sort_keys=True) + "\n")

    # Each shard file holds a one-entry state dict for its weight.
    for param_name, shard_name in weight_map.items():
        shard_path = os.path.join(folder, shard_name)
        torch.save({param_name: state_dict[param_name]}, shard_path)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_sentencepiece
|
||||
@@ -1038,6 +1058,31 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
|
||||
trainer.train(resume_from_checkpoint=False)
|
||||
|
||||
@require_torch_up_to_2_gpus
|
||||
def test_resume_training_with_shard_checkpoint(self):
    """Check that resuming from a checkpoint converted to the sharded format matches an uninterrupted run."""
    # With more than 2 GPUs this test fails: the batch size gets bigger and, with this number of
    # save_steps, training would resume at epoch 2 or more, so the data seen by the model would not be
    # the same (the training dataloader is shuffled).
    trainer_kwargs = {"train_len": 128, "save_steps": 5, "learning_rate": 0.1}

    with tempfile.TemporaryDirectory() as tmpdir:
        # Reference run: train from scratch and record the final weights and trainer state.
        regression_trainer = get_regression_trainer(output_dir=tmpdir, **trainer_kwargs)
        regression_trainer.train()
        a = regression_trainer.model.a.item()
        b = regression_trainer.model.b.item()
        state = dataclasses.asdict(regression_trainer.state)

        # Replace the checkpoint saved at step 5 with its sharded equivalent.
        checkpoint = os.path.join(tmpdir, "checkpoint-5")
        self.convert_to_sharded_checkpoint(checkpoint)

        # Fresh trainer with the same settings, resumed from the sharded checkpoint.
        regression_trainer = get_regression_trainer(output_dir=tmpdir, **trainer_kwargs)
        regression_trainer.train(resume_from_checkpoint=checkpoint)
        a1 = regression_trainer.model.a.item()
        b1 = regression_trainer.model.b.item()
        state1 = dataclasses.asdict(regression_trainer.state)

        self.assertEqual(a, a1)
        self.assertEqual(b, b1)
        self.check_trainer_state_are_the_same(state, state1)
|
||||
|
||||
@require_torch_up_to_2_gpus
|
||||
def test_resume_training_with_gradient_accumulation(self):
|
||||
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
|
||||
|
||||
Reference in New Issue
Block a user