Fix warning message for PEFT models in text-generation pipeline #36783 (#36887)

* add peft model in constant

* add test

* fix formating

* make fixup execute

* change code

* check by self.task

* add test

* fixup test code

* fix minor typo

* fix pipeline test

* apply maintainers reqests
This commit is contained in:
Sangyun_LEE (이상윤)
2025-04-09 23:36:52 +09:00
committed by GitHub
parent 2527f71a47
commit ad340908e4
2 changed files with 66 additions and 0 deletions

View File

@@ -814,3 +814,50 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
msg = "When using prompt learning PEFT methods such as PREFIX_TUNING"
with self.assertRaisesRegex(RuntimeError, msg):
trainer.train()
def test_peft_pipeline_no_warning(self):
"""
Test to verify that the warning message "The model 'PeftModel' is not supported for text-generation"
does not appear when using PeftModel with text-generation pipeline.
"""
from peft import PeftModel
from transformers import pipeline
ADAPTER_PATH = "peft-internal-testing/tiny-OPTForCausalLM-lora"
BASE_PATH = "hf-internal-testing/tiny-random-OPTForCausalLM"
# Input text for testing
text = "Who is a Elon Musk?"
expected_error_msg = "The model 'PeftModel' is not supported for text-generation"
model = AutoModelForCausalLM.from_pretrained(
BASE_PATH,
device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(BASE_PATH)
lora_model = PeftModel.from_pretrained(
model,
ADAPTER_PATH,
device_map="auto",
)
# Create pipeline with PEFT model while capturing log output
# Check that the warning message is not present in the logs
pipeline_logger = logging.get_logger("transformers.pipelines.base")
with self.assertNoLogs(pipeline_logger, logging.ERROR) as cl:
lora_generator = pipeline(
task="text-generation",
model=lora_model,
tokenizer=tokenizer,
max_length=10,
)
# Generate text to verify pipeline works
_ = lora_generator(text)
# Check that the warning message is not present in the logs
self.assertNotIn(
expected_error_msg, cl.out, f"Error message '{expected_error_msg}' should not appear when using PeftModel"
)