From e076953079469983b8430fdb032f8a61dc2ae5f4 Mon Sep 17 00:00:00 2001 From: Clara Pohland <54847419+claralp@users.noreply.github.com> Date: Mon, 6 May 2024 14:22:52 +0200 Subject: [PATCH] Trainer._load_from_checkpoint - support loading multiple Peft adapters (#30505) * Trainer: load checkpoint model with multiple adapters * Trainer._load_from_checkpoint support multiple active adapters * PeftModel.set_adapter does not support multiple adapters yet * Trainer._load_from_checkpoint test multiple adapters --------- Co-authored-by: Clara Luise Pohland --- src/transformers/trainer.py | 24 ++++++++++++++- tests/trainer/test_trainer.py | 57 +++++++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index d967c9314b..c18404dfcd 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2413,6 +2413,20 @@ class Trainer: # this checks the FSDP state dict when `FULL_STATE_DICT` is used or os.path.isfile(os.path.join(resume_from_checkpoint, f"{FSDP_MODEL_NAME}.bin")) ) + # if multiple adapters exist, they get saved in sub directories + adapter_subdirs = ( + [ + folder_name + for folder_name in os.listdir(resume_from_checkpoint) + if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name)) + and ( + os.path.isfile(os.path.join(resume_from_checkpoint, folder_name, ADAPTER_WEIGHTS_NAME)) + or os.path.isfile(os.path.join(resume_from_checkpoint, folder_name, ADAPTER_SAFE_WEIGHTS_NAME)) + ) + ] + if os.path.isdir(resume_from_checkpoint) + else [] + ) if is_fsdp_ckpt and not self.is_fsdp_enabled: raise ValueError(f"Checkpoint found at {resume_from_checkpoint} is only supported when using PyTorch FSDP") @@ -2430,6 +2444,7 @@ class Trainer: ] ) or is_fsdp_ckpt + or adapter_subdirs ): raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}") @@ -2503,7 +2518,14 @@ class Trainer: # If train a model using PEFT & LoRA, assume that adapter have been saved properly. if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"): if os.path.exists(resume_from_checkpoint): - model.load_adapter(resume_from_checkpoint, model.active_adapter, is_trainable=True) + if adapter_subdirs: + active_adapter = model.active_adapter + for subdir_name in adapter_subdirs: + peft_id = os.path.join(resume_from_checkpoint, subdir_name) + model.load_adapter(peft_id, subdir_name, is_trainable=(subdir_name == active_adapter)) + model.set_adapter(active_adapter) + else: + model.load_adapter(resume_from_checkpoint, model.active_adapter, is_trainable=True) else: logger.warning( "The intermediate checkpoints of PEFT may not be saved correctly, " diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 8913de4db1..89b26221f3 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -964,6 +964,63 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): with self.assertRaises(ValueError): _ = Trainer(tiny_model, args, train_dataset=train_dataset) # noqa + @require_peft + def test_multiple_peft_adapters(self): + from peft import LoraConfig, get_peft_model + + # Tests if resuming from checkpoint works if the model has multiple adapters + + MODEL_ID = "hf-internal-testing/tiny-random-LlamaForCausalLM" + tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + tiny_model = AutoModelForCausalLM.from_pretrained(MODEL_ID) + + peft_config = LoraConfig( + r=4, + lora_alpha=16, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + tiny_model = get_peft_model(tiny_model, peft_config, "adapter1") + tiny_model.add_adapter("adapter2", peft_config) + + train_dataset = LineByLineTextDataset( + tokenizer=tokenizer, + file_path=PATH_SAMPLE_TEXT, + block_size=tokenizer.max_len_single_sentence, + ) + for example in train_dataset.examples: + example["labels"] = example["input_ids"] + + tokenizer.pad_token = tokenizer.eos_token + + with tempfile.TemporaryDirectory() as tmpdir: + args = TrainingArguments( + tmpdir, + per_device_train_batch_size=1, + learning_rate=1e-9, + save_steps=5, + logging_steps=5, + max_steps=10, + use_cpu=True, + ) + trainer = Trainer(tiny_model, args, tokenizer=tokenizer, train_dataset=train_dataset) + + trainer.train() + parameters = dict(tiny_model.named_parameters()) + state = dataclasses.asdict(trainer.state) + + # Reinitialize trainer + trainer = Trainer(tiny_model, args, tokenizer=tokenizer, train_dataset=train_dataset) + + checkpoint = os.path.join(tmpdir, "checkpoint-5") + + trainer.train(resume_from_checkpoint=checkpoint) + parameters1 = dict(tiny_model.named_parameters()) + state1 = dataclasses.asdict(trainer.state) + self.assertEqual(parameters, parameters1) + self.check_trainer_state_are_the_same(state, state1) + @require_bitsandbytes def test_rmsprop_bnb(self): config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)