Trainer._load_from_checkpoint - support loading multiple Peft adapters (#30505)
* Trainer: load checkpoint model with multiple adapters * Trainer._load_from_checkpoint support multiple active adapters * PeftModel.set_adapter does not support multiple adapters yet * Trainer._load_from_checkpoint test multiple adapters --------- Co-authored-by: Clara Luise Pohland <clara-luise.pohland@telekom.de>
This commit is contained in:
@@ -2413,6 +2413,20 @@ class Trainer:
|
|||||||
# this checks the FSDP state dict when `FULL_STATE_DICT` is used
|
# this checks the FSDP state dict when `FULL_STATE_DICT` is used
|
||||||
or os.path.isfile(os.path.join(resume_from_checkpoint, f"{FSDP_MODEL_NAME}.bin"))
|
or os.path.isfile(os.path.join(resume_from_checkpoint, f"{FSDP_MODEL_NAME}.bin"))
|
||||||
)
|
)
|
||||||
|
# if multiple adapters exist, they get saved in sub directories
|
||||||
|
adapter_subdirs = (
|
||||||
|
[
|
||||||
|
folder_name
|
||||||
|
for folder_name in os.listdir(resume_from_checkpoint)
|
||||||
|
if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
|
||||||
|
and (
|
||||||
|
os.path.isfile(os.path.join(resume_from_checkpoint, folder_name, ADAPTER_WEIGHTS_NAME))
|
||||||
|
or os.path.isfile(os.path.join(resume_from_checkpoint, folder_name, ADAPTER_SAFE_WEIGHTS_NAME))
|
||||||
|
)
|
||||||
|
]
|
||||||
|
if os.path.isdir(resume_from_checkpoint)
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
|
||||||
if is_fsdp_ckpt and not self.is_fsdp_enabled:
|
if is_fsdp_ckpt and not self.is_fsdp_enabled:
|
||||||
raise ValueError(f"Checkpoint found at {resume_from_checkpoint} is only supported when using PyTorch FSDP")
|
raise ValueError(f"Checkpoint found at {resume_from_checkpoint} is only supported when using PyTorch FSDP")
|
||||||
@@ -2430,6 +2444,7 @@ class Trainer:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
or is_fsdp_ckpt
|
or is_fsdp_ckpt
|
||||||
|
or adapter_subdirs
|
||||||
):
|
):
|
||||||
raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
|
raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
|
||||||
|
|
||||||
@@ -2503,7 +2518,14 @@ class Trainer:
|
|||||||
# If train a model using PEFT & LoRA, assume that adapter have been saved properly.
|
# If train a model using PEFT & LoRA, assume that adapter have been saved properly.
|
||||||
if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
|
if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
|
||||||
if os.path.exists(resume_from_checkpoint):
|
if os.path.exists(resume_from_checkpoint):
|
||||||
model.load_adapter(resume_from_checkpoint, model.active_adapter, is_trainable=True)
|
if adapter_subdirs:
|
||||||
|
active_adapter = model.active_adapter
|
||||||
|
for subdir_name in adapter_subdirs:
|
||||||
|
peft_id = os.path.join(resume_from_checkpoint, subdir_name)
|
||||||
|
model.load_adapter(peft_id, subdir_name, is_trainable=(subdir_name == active_adapter))
|
||||||
|
model.set_adapter(active_adapter)
|
||||||
|
else:
|
||||||
|
model.load_adapter(resume_from_checkpoint, model.active_adapter, is_trainable=True)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"The intermediate checkpoints of PEFT may not be saved correctly, "
|
"The intermediate checkpoints of PEFT may not be saved correctly, "
|
||||||
|
|||||||
@@ -964,6 +964,63 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
_ = Trainer(tiny_model, args, train_dataset=train_dataset) # noqa
|
_ = Trainer(tiny_model, args, train_dataset=train_dataset) # noqa
|
||||||
|
|
||||||
|
@require_peft
|
||||||
|
def test_multiple_peft_adapters(self):
|
||||||
|
from peft import LoraConfig, get_peft_model
|
||||||
|
|
||||||
|
# Tests if resuming from checkpoint works if the model has multiple adapters
|
||||||
|
|
||||||
|
MODEL_ID = "hf-internal-testing/tiny-random-LlamaForCausalLM"
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
||||||
|
tiny_model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
|
||||||
|
|
||||||
|
peft_config = LoraConfig(
|
||||||
|
r=4,
|
||||||
|
lora_alpha=16,
|
||||||
|
lora_dropout=0.05,
|
||||||
|
bias="none",
|
||||||
|
task_type="CAUSAL_LM",
|
||||||
|
)
|
||||||
|
tiny_model = get_peft_model(tiny_model, peft_config, "adapter1")
|
||||||
|
tiny_model.add_adapter("adapter2", peft_config)
|
||||||
|
|
||||||
|
train_dataset = LineByLineTextDataset(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
file_path=PATH_SAMPLE_TEXT,
|
||||||
|
block_size=tokenizer.max_len_single_sentence,
|
||||||
|
)
|
||||||
|
for example in train_dataset.examples:
|
||||||
|
example["labels"] = example["input_ids"]
|
||||||
|
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
args = TrainingArguments(
|
||||||
|
tmpdir,
|
||||||
|
per_device_train_batch_size=1,
|
||||||
|
learning_rate=1e-9,
|
||||||
|
save_steps=5,
|
||||||
|
logging_steps=5,
|
||||||
|
max_steps=10,
|
||||||
|
use_cpu=True,
|
||||||
|
)
|
||||||
|
trainer = Trainer(tiny_model, args, tokenizer=tokenizer, train_dataset=train_dataset)
|
||||||
|
|
||||||
|
trainer.train()
|
||||||
|
parameters = dict(tiny_model.named_parameters())
|
||||||
|
state = dataclasses.asdict(trainer.state)
|
||||||
|
|
||||||
|
# Reinitialize trainer
|
||||||
|
trainer = Trainer(tiny_model, args, tokenizer=tokenizer, train_dataset=train_dataset)
|
||||||
|
|
||||||
|
checkpoint = os.path.join(tmpdir, "checkpoint-5")
|
||||||
|
|
||||||
|
trainer.train(resume_from_checkpoint=checkpoint)
|
||||||
|
parameters1 = dict(tiny_model.named_parameters())
|
||||||
|
state1 = dataclasses.asdict(trainer.state)
|
||||||
|
self.assertEqual(parameters, parameters1)
|
||||||
|
self.check_trainer_state_are_the_same(state, state1)
|
||||||
|
|
||||||
@require_bitsandbytes
|
@require_bitsandbytes
|
||||||
def test_rmsprop_bnb(self):
|
def test_rmsprop_bnb(self):
|
||||||
config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
|
config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
|
||||||
|
|||||||
Reference in New Issue
Block a user