Trainer._load_from_checkpoint - support loading multiple Peft adapters (#30505)
* Trainer: load checkpoint model with multiple adapters * Trainer._load_from_checkpoint support multiple active adapters * PeftModel.set_adapter does not support multiple adapters yet * Trainer._load_from_checkpoint test multiple adapters --------- Co-authored-by: Clara Luise Pohland <clara-luise.pohland@telekom.de>
This commit is contained in:
@@ -2413,6 +2413,20 @@ class Trainer:
|
||||
# this checks the FSDP state dict when `FULL_STATE_DICT` is used
|
||||
or os.path.isfile(os.path.join(resume_from_checkpoint, f"{FSDP_MODEL_NAME}.bin"))
|
||||
)
|
||||
# if multiple adapters exist, they get saved in sub directories
|
||||
adapter_subdirs = (
|
||||
[
|
||||
folder_name
|
||||
for folder_name in os.listdir(resume_from_checkpoint)
|
||||
if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
|
||||
and (
|
||||
os.path.isfile(os.path.join(resume_from_checkpoint, folder_name, ADAPTER_WEIGHTS_NAME))
|
||||
or os.path.isfile(os.path.join(resume_from_checkpoint, folder_name, ADAPTER_SAFE_WEIGHTS_NAME))
|
||||
)
|
||||
]
|
||||
if os.path.isdir(resume_from_checkpoint)
|
||||
else []
|
||||
)
|
||||
|
||||
if is_fsdp_ckpt and not self.is_fsdp_enabled:
|
||||
raise ValueError(f"Checkpoint found at {resume_from_checkpoint} is only supported when using PyTorch FSDP")
|
||||
@@ -2430,6 +2444,7 @@ class Trainer:
|
||||
]
|
||||
)
|
||||
or is_fsdp_ckpt
|
||||
or adapter_subdirs
|
||||
):
|
||||
raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
|
||||
|
||||
@@ -2503,6 +2518,13 @@ class Trainer:
|
||||
# If train a model using PEFT & LoRA, assume that adapter have been saved properly.
|
||||
if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
|
||||
if os.path.exists(resume_from_checkpoint):
|
||||
if adapter_subdirs:
|
||||
active_adapter = model.active_adapter
|
||||
for subdir_name in adapter_subdirs:
|
||||
peft_id = os.path.join(resume_from_checkpoint, subdir_name)
|
||||
model.load_adapter(peft_id, subdir_name, is_trainable=(subdir_name == active_adapter))
|
||||
model.set_adapter(active_adapter)
|
||||
else:
|
||||
model.load_adapter(resume_from_checkpoint, model.active_adapter, is_trainable=True)
|
||||
else:
|
||||
logger.warning(
|
||||
|
||||
@@ -964,6 +964,63 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
with self.assertRaises(ValueError):
|
||||
_ = Trainer(tiny_model, args, train_dataset=train_dataset) # noqa
|
||||
|
||||
@require_peft
|
||||
def test_multiple_peft_adapters(self):
|
||||
from peft import LoraConfig, get_peft_model
|
||||
|
||||
# Tests if resuming from checkpoint works if the model has multiple adapters
|
||||
|
||||
MODEL_ID = "hf-internal-testing/tiny-random-LlamaForCausalLM"
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
||||
tiny_model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
|
||||
|
||||
peft_config = LoraConfig(
|
||||
r=4,
|
||||
lora_alpha=16,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
tiny_model = get_peft_model(tiny_model, peft_config, "adapter1")
|
||||
tiny_model.add_adapter("adapter2", peft_config)
|
||||
|
||||
train_dataset = LineByLineTextDataset(
|
||||
tokenizer=tokenizer,
|
||||
file_path=PATH_SAMPLE_TEXT,
|
||||
block_size=tokenizer.max_len_single_sentence,
|
||||
)
|
||||
for example in train_dataset.examples:
|
||||
example["labels"] = example["input_ids"]
|
||||
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
args = TrainingArguments(
|
||||
tmpdir,
|
||||
per_device_train_batch_size=1,
|
||||
learning_rate=1e-9,
|
||||
save_steps=5,
|
||||
logging_steps=5,
|
||||
max_steps=10,
|
||||
use_cpu=True,
|
||||
)
|
||||
trainer = Trainer(tiny_model, args, tokenizer=tokenizer, train_dataset=train_dataset)
|
||||
|
||||
trainer.train()
|
||||
parameters = dict(tiny_model.named_parameters())
|
||||
state = dataclasses.asdict(trainer.state)
|
||||
|
||||
# Reinitialize trainer
|
||||
trainer = Trainer(tiny_model, args, tokenizer=tokenizer, train_dataset=train_dataset)
|
||||
|
||||
checkpoint = os.path.join(tmpdir, "checkpoint-5")
|
||||
|
||||
trainer.train(resume_from_checkpoint=checkpoint)
|
||||
parameters1 = dict(tiny_model.named_parameters())
|
||||
state1 = dataclasses.asdict(trainer.state)
|
||||
self.assertEqual(parameters, parameters1)
|
||||
self.check_trainer_state_are_the_same(state, state1)
|
||||
|
||||
@require_bitsandbytes
|
||||
def test_rmsprop_bnb(self):
|
||||
config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
|
||||
|
||||
Reference in New Issue
Block a user