Load checkpoint without re-creating the model (#11318)
This commit is contained in:
@@ -271,7 +271,7 @@ class PretrainedConfig(object):
|
|||||||
self._name_or_path = str(kwargs.pop("name_or_path", ""))
|
self._name_or_path = str(kwargs.pop("name_or_path", ""))
|
||||||
|
|
||||||
# Drop the transformers version info
|
# Drop the transformers version info
|
||||||
kwargs.pop("transformers_version", None)
|
self.transformers_version = kwargs.pop("transformers_version", None)
|
||||||
|
|
||||||
# Additional attributes without default values
|
# Additional attributes without default values
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
|
|||||||
@@ -55,9 +55,12 @@ from torch.utils.data.dataset import Dataset, IterableDataset
|
|||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from torch.utils.data.sampler import RandomSampler, SequentialSampler
|
from torch.utils.data.sampler import RandomSampler, SequentialSampler
|
||||||
|
|
||||||
|
from . import __version__
|
||||||
|
from .configuration_utils import PretrainedConfig
|
||||||
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
|
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
|
||||||
from .dependency_versions_check import dep_version_check
|
from .dependency_versions_check import dep_version_check
|
||||||
from .file_utils import (
|
from .file_utils import (
|
||||||
|
CONFIG_NAME,
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
is_apex_available,
|
is_apex_available,
|
||||||
is_datasets_available,
|
is_datasets_available,
|
||||||
@@ -999,14 +1002,23 @@ class Trainer:
|
|||||||
|
|
||||||
logger.info(f"Loading model from {resume_from_checkpoint}).")
|
logger.info(f"Loading model from {resume_from_checkpoint}).")
|
||||||
|
|
||||||
|
if os.path.isfile(os.path.join(resume_from_checkpoint, CONFIG_NAME)):
|
||||||
|
config = PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
|
||||||
|
checkpoint_version = config.transformers_version
|
||||||
|
if checkpoint_version is not None and checkpoint_version != __version__:
|
||||||
|
logger.warn(
|
||||||
|
f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
|
||||||
|
f"Transformers but your current version is {__version__}. This is not recommended and could "
|
||||||
|
"yield to errors or unwanted behaviors."
|
||||||
|
)
|
||||||
|
|
||||||
if self.deepspeed:
|
if self.deepspeed:
|
||||||
# will be resumed in deepspeed_init
|
# will be resumed in deepspeed_init
|
||||||
pass
|
pass
|
||||||
elif isinstance(self.model, PreTrainedModel):
|
|
||||||
self.model = self.model.from_pretrained(resume_from_checkpoint)
|
|
||||||
model_reloaded = True
|
|
||||||
else:
|
else:
|
||||||
state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME))
|
# We load the model state dict on the CPU to avoid an OOM error.
|
||||||
|
state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
|
||||||
|
# If the model is on the GPU, it still works!
|
||||||
self.model.load_state_dict(state_dict)
|
self.model.load_state_dict(state_dict)
|
||||||
|
|
||||||
# If model was re-initialized, put it on the right device and update self.model_wrapped
|
# If model was re-initialized, put it on the right device and update self.model_wrapped
|
||||||
@@ -1293,13 +1305,10 @@ class Trainer:
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
|
f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
|
||||||
)
|
)
|
||||||
if isinstance(self.model, PreTrainedModel):
|
# We load the model state dict on the CPU to avoid an OOM error.
|
||||||
self.model = self.model.from_pretrained(self.state.best_model_checkpoint)
|
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME), map_location="cpu")
|
||||||
if self.place_model_on_device:
|
# If the model is on the GPU, it still works!
|
||||||
self.model = self.model.to(args.device)
|
self.model.load_state_dict(state_dict)
|
||||||
else:
|
|
||||||
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
|
|
||||||
self.model.load_state_dict(state_dict)
|
|
||||||
|
|
||||||
if self.deepspeed:
|
if self.deepspeed:
|
||||||
self.deepspeed.load_checkpoint(
|
self.deepspeed.load_checkpoint(
|
||||||
|
|||||||
@@ -725,6 +725,46 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertEqual(b, b1)
|
self.assertEqual(b, b1)
|
||||||
self.check_trainer_state_are_the_same(state, state1)
|
self.check_trainer_state_are_the_same(state, state1)
|
||||||
|
|
||||||
|
def test_resume_training_with_frozen_params(self):
|
||||||
|
if torch.cuda.device_count() > 2:
|
||||||
|
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
|
||||||
|
# save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
|
||||||
|
# won't be the same since the training dataloader is shuffled).
|
||||||
|
return
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
trainer = get_regression_trainer(
|
||||||
|
output_dir=tmpdir,
|
||||||
|
train_len=128,
|
||||||
|
per_device_train_batch_size=4,
|
||||||
|
save_steps=5,
|
||||||
|
learning_rate=0.1,
|
||||||
|
)
|
||||||
|
trainer.model.a.requires_grad_(False)
|
||||||
|
trainer.train()
|
||||||
|
(a, b) = trainer.model.a.item(), trainer.model.b.item()
|
||||||
|
state = dataclasses.asdict(trainer.state)
|
||||||
|
|
||||||
|
checkpoint = os.path.join(tmpdir, "checkpoint-5")
|
||||||
|
|
||||||
|
# Reinitialize trainer
|
||||||
|
trainer = get_regression_trainer(
|
||||||
|
output_dir=tmpdir,
|
||||||
|
train_len=128,
|
||||||
|
per_device_train_batch_size=4,
|
||||||
|
save_steps=5,
|
||||||
|
learning_rate=0.1,
|
||||||
|
)
|
||||||
|
trainer.model.a.requires_grad_(False)
|
||||||
|
|
||||||
|
trainer.train(resume_from_checkpoint=checkpoint)
|
||||||
|
|
||||||
|
self.assertFalse(trainer.model.a.requires_grad)
|
||||||
|
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||||
|
state1 = dataclasses.asdict(trainer.state)
|
||||||
|
self.assertEqual(a, a1)
|
||||||
|
self.assertEqual(b, b1)
|
||||||
|
self.check_trainer_state_are_the_same(state, state1)
|
||||||
|
|
||||||
def test_load_best_model_at_end(self):
|
def test_load_best_model_at_end(self):
|
||||||
total = int(self.n_epochs * 64 / self.batch_size)
|
total = int(self.n_epochs * 64 / self.batch_size)
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
|||||||
Reference in New Issue
Block a user