From bcc50cc7ce12c458ef9e82e189652efe9150a4d0 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 11 Dec 2024 12:44:39 +0100 Subject: [PATCH] [PEFT] Better Trainer error when prompt learning with loading best model at the end (#35087) Original issue: https://github.com/huggingface/peft/issues/2256 There is a potential error when using load_best_model_at_end=True with a prompt learning PEFT method. This is because Trainer uses load_adapter under the hood but with some prompt learning methods, there is an optimization on the saved model to remove parameters that are not required for inference, which in turn requires a change to the model architecture. This is why load_adapter will fail in such cases and users should instead set load_best_model_at_end=False and use PeftModel.from_pretrained. As this is not obvious, we now intercept the error and add a helpful error message. --- src/transformers/trainer.py | 17 +++- .../peft_integration/test_peft_integration.py | 84 ++++++++++++++++++- 2 files changed, 99 insertions(+), 2 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index be41a415e5..a708d8deb4 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2938,7 +2938,22 @@ class Trainer: active_adapter = model.active_adapter if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path): - model.load_adapter(self.state.best_model_checkpoint, active_adapter) + try: + model.load_adapter(self.state.best_model_checkpoint, active_adapter) + except RuntimeError as exc: + if model.peft_config[active_adapter].is_prompt_learning: + # for context: https://github.com/huggingface/peft/issues/2256 + msg = ( + "When using prompt learning PEFT methods such as " + f"{model.peft_config[active_adapter].peft_type.value}, setting " + "load_best_model_at_end=True can lead to errors, it is recommended " + "to set this to False and to load the model manually from the checkpoint " + "directory using PeftModel.from_pretrained(base_model, ) after training " + "has finished." + ) + raise RuntimeError(msg) from exc + else: + raise # Load_adapter has no return value present, modify it when appropriate. from torch.nn.modules.module import _IncompatibleKeys diff --git a/tests/peft_integration/test_peft_integration.py b/tests/peft_integration/test_peft_integration.py index 48fd6da3d6..bdbccee5ad 100644 --- a/tests/peft_integration/test_peft_integration.py +++ b/tests/peft_integration/test_peft_integration.py @@ -17,10 +17,19 @@ import os import tempfile import unittest +from datasets import Dataset, DatasetDict from huggingface_hub import hf_hub_download from packaging import version -from transformers import AutoModelForCausalLM, OPTForCausalLM, logging +from transformers import ( + AutoModelForCausalLM, + AutoModelForSequenceClassification, + AutoTokenizer, + OPTForCausalLM, + Trainer, + TrainingArguments, + logging, +) from transformers.testing_utils import ( CaptureLogger, require_bitsandbytes, @@ -665,3 +674,76 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin): else: assert not module.training assert all(not p.requires_grad for p in module.parameters()) + + def test_prefix_tuning_trainer_load_best_model_at_end_error(self): + # Original issue: https://github.com/huggingface/peft/issues/2256 + # There is a potential error when using load_best_model_at_end=True with a prompt learning PEFT method. This is + # because Trainer uses load_adapter under the hood but with some prompt learning methods, there is an + # optimization on the saved model to remove parameters that are not required for inference, which in turn + # requires a change to the model architecture. This is why load_adapter will fail in such cases and users should + # instead set load_best_model_at_end=False and use PeftModel.from_pretrained. As this is not obvious, we now + # intercept the error and add a helpful error message. + # This test checks this error message. It also tests the "happy path" (i.e. no error) when using LoRA. + from peft import LoraConfig, PrefixTuningConfig, TaskType, get_peft_model + + # create a small sequence classification dataset (binary classification) + dataset = [] + for i, row in enumerate(os.__doc__.splitlines()): + dataset.append({"text": row, "label": i % 2}) + ds_train = Dataset.from_list(dataset) + ds_valid = ds_train + datasets = DatasetDict( + { + "train": ds_train, + "val": ds_valid, + } + ) + + # tokenizer for peft-internal-testing/tiny-OPTForCausalLM-lora cannot be loaded, thus using + # hf-internal-testing/tiny-random-OPTForCausalLM + model_id = "hf-internal-testing/tiny-random-OPTForCausalLM" + tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left", model_type="opt") + + def tokenize_function(examples): + return tokenizer(examples["text"], max_length=128, truncation=True, padding="max_length") + + tokenized_datasets = datasets.map(tokenize_function, batched=True) + # lora works, prefix-tuning is expected to raise an error + peft_configs = { + "lora": LoraConfig(task_type=TaskType.SEQ_CLS), + "prefix-tuning": PrefixTuningConfig( + task_type=TaskType.SEQ_CLS, + inference_mode=False, + prefix_projection=True, + num_virtual_tokens=10, + ), + } + + for peft_type, peft_config in peft_configs.items(): + base_model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=2) + base_model.config.pad_token_id = tokenizer.pad_token_id + peft_model = get_peft_model(base_model, peft_config) + + with tempfile.TemporaryDirectory() as tmpdirname: + training_args = TrainingArguments( + output_dir=tmpdirname, + num_train_epochs=3, + eval_strategy="epoch", + save_strategy="epoch", + load_best_model_at_end=True, + ) + trainer = Trainer( + model=peft_model, + args=training_args, + train_dataset=tokenized_datasets["train"], + eval_dataset=tokenized_datasets["val"], + ) + + if peft_type == "lora": + # LoRA works with load_best_model_at_end + trainer.train() + else: + # prefix tuning does not work, but at least users should get a helpful error message + msg = "When using prompt learning PEFT methods such as PREFIX_TUNING" + with self.assertRaisesRegex(RuntimeError, msg): + trainer.train()