[PEFT] Better Trainer error when prompt learning with loading best model at the end (#35087)
Original issue: https://github.com/huggingface/peft/issues/2256 There is a potential error when using load_best_model_at_end=True with a prompt learning PEFT method. This is because Trainer uses load_adapter under the hood but with some prompt learning methods, there is an optimization on the saved model to remove parameters that are not required for inference, which in turn requires a change to the model architecture. This is why load_adapter will fail in such cases and users should instead set load_best_model_at_end=False and use PeftModel.from_pretrained. As this is not obvious, we now intercept the error and add a helpful error message.
This commit is contained in:
@@ -2938,7 +2938,22 @@ class Trainer:
|
|||||||
active_adapter = model.active_adapter
|
active_adapter = model.active_adapter
|
||||||
|
|
||||||
if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
|
if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
|
||||||
|
try:
|
||||||
model.load_adapter(self.state.best_model_checkpoint, active_adapter)
|
model.load_adapter(self.state.best_model_checkpoint, active_adapter)
|
||||||
|
except RuntimeError as exc:
|
||||||
|
if model.peft_config[active_adapter].is_prompt_learning:
|
||||||
|
# for context: https://github.com/huggingface/peft/issues/2256
|
||||||
|
msg = (
|
||||||
|
"When using prompt learning PEFT methods such as "
|
||||||
|
f"{model.peft_config[active_adapter].peft_type.value}, setting "
|
||||||
|
"load_best_model_at_end=True can lead to errors, it is recommended "
|
||||||
|
"to set this to False and to load the model manually from the checkpoint "
|
||||||
|
"directory using PeftModel.from_pretrained(base_model, <path>) after training "
|
||||||
|
"has finished."
|
||||||
|
)
|
||||||
|
raise RuntimeError(msg) from exc
|
||||||
|
else:
|
||||||
|
raise
|
||||||
# Load_adapter has no return value present, modify it when appropriate.
|
# Load_adapter has no return value present, modify it when appropriate.
|
||||||
from torch.nn.modules.module import _IncompatibleKeys
|
from torch.nn.modules.module import _IncompatibleKeys
|
||||||
|
|
||||||
|
|||||||
@@ -17,10 +17,19 @@ import os
|
|||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
from datasets import Dataset, DatasetDict
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
from transformers import AutoModelForCausalLM, OPTForCausalLM, logging
|
from transformers import (
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
AutoModelForSequenceClassification,
|
||||||
|
AutoTokenizer,
|
||||||
|
OPTForCausalLM,
|
||||||
|
Trainer,
|
||||||
|
TrainingArguments,
|
||||||
|
logging,
|
||||||
|
)
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
CaptureLogger,
|
CaptureLogger,
|
||||||
require_bitsandbytes,
|
require_bitsandbytes,
|
||||||
@@ -665,3 +674,76 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
|||||||
else:
|
else:
|
||||||
assert not module.training
|
assert not module.training
|
||||||
assert all(not p.requires_grad for p in module.parameters())
|
assert all(not p.requires_grad for p in module.parameters())
|
||||||
|
|
||||||
|
def test_prefix_tuning_trainer_load_best_model_at_end_error(self):
|
||||||
|
# Original issue: https://github.com/huggingface/peft/issues/2256
|
||||||
|
# There is a potential error when using load_best_model_at_end=True with a prompt learning PEFT method. This is
|
||||||
|
# because Trainer uses load_adapter under the hood but with some prompt learning methods, there is an
|
||||||
|
# optimization on the saved model to remove parameters that are not required for inference, which in turn
|
||||||
|
# requires a change to the model architecture. This is why load_adapter will fail in such cases and users should
|
||||||
|
# instead set load_best_model_at_end=False and use PeftModel.from_pretrained. As this is not obvious, we now
|
||||||
|
# intercept the error and add a helpful error message.
|
||||||
|
# This test checks this error message. It also tests the "happy path" (i.e. no error) when using LoRA.
|
||||||
|
from peft import LoraConfig, PrefixTuningConfig, TaskType, get_peft_model
|
||||||
|
|
||||||
|
# create a small sequence classification dataset (binary classification)
|
||||||
|
dataset = []
|
||||||
|
for i, row in enumerate(os.__doc__.splitlines()):
|
||||||
|
dataset.append({"text": row, "label": i % 2})
|
||||||
|
ds_train = Dataset.from_list(dataset)
|
||||||
|
ds_valid = ds_train
|
||||||
|
datasets = DatasetDict(
|
||||||
|
{
|
||||||
|
"train": ds_train,
|
||||||
|
"val": ds_valid,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# tokenizer for peft-internal-testing/tiny-OPTForCausalLM-lora cannot be loaded, thus using
|
||||||
|
# hf-internal-testing/tiny-random-OPTForCausalLM
|
||||||
|
model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left", model_type="opt")
|
||||||
|
|
||||||
|
def tokenize_function(examples):
|
||||||
|
return tokenizer(examples["text"], max_length=128, truncation=True, padding="max_length")
|
||||||
|
|
||||||
|
tokenized_datasets = datasets.map(tokenize_function, batched=True)
|
||||||
|
# lora works, prefix-tuning is expected to raise an error
|
||||||
|
peft_configs = {
|
||||||
|
"lora": LoraConfig(task_type=TaskType.SEQ_CLS),
|
||||||
|
"prefix-tuning": PrefixTuningConfig(
|
||||||
|
task_type=TaskType.SEQ_CLS,
|
||||||
|
inference_mode=False,
|
||||||
|
prefix_projection=True,
|
||||||
|
num_virtual_tokens=10,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
for peft_type, peft_config in peft_configs.items():
|
||||||
|
base_model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=2)
|
||||||
|
base_model.config.pad_token_id = tokenizer.pad_token_id
|
||||||
|
peft_model = get_peft_model(base_model, peft_config)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
training_args = TrainingArguments(
|
||||||
|
output_dir=tmpdirname,
|
||||||
|
num_train_epochs=3,
|
||||||
|
eval_strategy="epoch",
|
||||||
|
save_strategy="epoch",
|
||||||
|
load_best_model_at_end=True,
|
||||||
|
)
|
||||||
|
trainer = Trainer(
|
||||||
|
model=peft_model,
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=tokenized_datasets["train"],
|
||||||
|
eval_dataset=tokenized_datasets["val"],
|
||||||
|
)
|
||||||
|
|
||||||
|
if peft_type == "lora":
|
||||||
|
# LoRA works with load_best_model_at_end
|
||||||
|
trainer.train()
|
||||||
|
else:
|
||||||
|
# prefix tuning does not work, but at least users should get a helpful error message
|
||||||
|
msg = "When using prompt learning PEFT methods such as PREFIX_TUNING"
|
||||||
|
with self.assertRaisesRegex(RuntimeError, msg):
|
||||||
|
trainer.train()
|
||||||
|
|||||||
Reference in New Issue
Block a user