Load checkpoint without re-creating the model (#11318)
This commit is contained in:
@@ -271,7 +271,7 @@ class PretrainedConfig(object):
|
||||
self._name_or_path = str(kwargs.pop("name_or_path", ""))
|
||||
|
||||
# Drop the transformers version info
|
||||
kwargs.pop("transformers_version", None)
|
||||
self.transformers_version = kwargs.pop("transformers_version", None)
|
||||
|
||||
# Additional attributes without default values
|
||||
for key, value in kwargs.items():
|
||||
|
||||
@@ -55,9 +55,12 @@ from torch.utils.data.dataset import Dataset, IterableDataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.utils.data.sampler import RandomSampler, SequentialSampler
|
||||
|
||||
from . import __version__
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
|
||||
from .dependency_versions_check import dep_version_check
|
||||
from .file_utils import (
|
||||
CONFIG_NAME,
|
||||
WEIGHTS_NAME,
|
||||
is_apex_available,
|
||||
is_datasets_available,
|
||||
@@ -999,14 +1002,23 @@ class Trainer:
|
||||
|
||||
logger.info(f"Loading model from {resume_from_checkpoint}).")
|
||||
|
||||
if os.path.isfile(os.path.join(resume_from_checkpoint, CONFIG_NAME)):
|
||||
config = PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
|
||||
checkpoint_version = config.transformers_version
|
||||
if checkpoint_version is not None and checkpoint_version != __version__:
|
||||
logger.warn(
|
||||
f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
|
||||
f"Transformers but your current version is {__version__}. This is not recommended and could "
|
||||
"yield to errors or unwanted behaviors."
|
||||
)
|
||||
|
||||
if self.deepspeed:
|
||||
# will be resumed in deepspeed_init
|
||||
pass
|
||||
elif isinstance(self.model, PreTrainedModel):
|
||||
self.model = self.model.from_pretrained(resume_from_checkpoint)
|
||||
model_reloaded = True
|
||||
else:
|
||||
state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME))
|
||||
# We load the model state dict on the CPU to avoid an OOM error.
|
||||
state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
|
||||
# If the model is on the GPU, it still works!
|
||||
self.model.load_state_dict(state_dict)
|
||||
|
||||
# If model was re-initialized, put it on the right device and update self.model_wrapped
|
||||
@@ -1293,13 +1305,10 @@ class Trainer:
|
||||
logger.info(
|
||||
f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
|
||||
)
|
||||
if isinstance(self.model, PreTrainedModel):
|
||||
self.model = self.model.from_pretrained(self.state.best_model_checkpoint)
|
||||
if self.place_model_on_device:
|
||||
self.model = self.model.to(args.device)
|
||||
else:
|
||||
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
|
||||
self.model.load_state_dict(state_dict)
|
||||
# We load the model state dict on the CPU to avoid an OOM error.
|
||||
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME), map_location="cpu")
|
||||
# If the model is on the GPU, it still works!
|
||||
self.model.load_state_dict(state_dict)
|
||||
|
||||
if self.deepspeed:
|
||||
self.deepspeed.load_checkpoint(
|
||||
|
||||
@@ -725,6 +725,46 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertEqual(b, b1)
|
||||
self.check_trainer_state_are_the_same(state, state1)
|
||||
|
||||
def test_resume_training_with_frozen_params(self):
|
||||
if torch.cuda.device_count() > 2:
|
||||
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
|
||||
# save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
|
||||
# won't be the same since the training dataloader is shuffled).
|
||||
return
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
trainer = get_regression_trainer(
|
||||
output_dir=tmpdir,
|
||||
train_len=128,
|
||||
per_device_train_batch_size=4,
|
||||
save_steps=5,
|
||||
learning_rate=0.1,
|
||||
)
|
||||
trainer.model.a.requires_grad_(False)
|
||||
trainer.train()
|
||||
(a, b) = trainer.model.a.item(), trainer.model.b.item()
|
||||
state = dataclasses.asdict(trainer.state)
|
||||
|
||||
checkpoint = os.path.join(tmpdir, "checkpoint-5")
|
||||
|
||||
# Reinitialize trainer
|
||||
trainer = get_regression_trainer(
|
||||
output_dir=tmpdir,
|
||||
train_len=128,
|
||||
per_device_train_batch_size=4,
|
||||
save_steps=5,
|
||||
learning_rate=0.1,
|
||||
)
|
||||
trainer.model.a.requires_grad_(False)
|
||||
|
||||
trainer.train(resume_from_checkpoint=checkpoint)
|
||||
|
||||
self.assertFalse(trainer.model.a.requires_grad)
|
||||
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
|
||||
state1 = dataclasses.asdict(trainer.state)
|
||||
self.assertEqual(a, a1)
|
||||
self.assertEqual(b, b1)
|
||||
self.check_trainer_state_are_the_same(state, state1)
|
||||
|
||||
def test_load_best_model_at_end(self):
|
||||
total = int(self.n_epochs * 64 / self.batch_size)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
|
||||
Reference in New Issue
Block a user