Fix warning message for PEFT models in text-generation pipeline #36783 (#36887)

* add peft model in constant

* add test

* fix formating

* make fixup execute

* change code

* check by self.task

* add test

* fixup test code

* fix minor typo

* fix pipeline test

* apply maintainers reqests
This commit is contained in:
Sangyun_LEE (이상윤)
2025-04-09 23:36:52 +09:00
committed by GitHub
parent 2527f71a47
commit ad340908e4
2 changed files with 66 additions and 0 deletions

View File

@@ -849,6 +849,22 @@ PIPELINE_INIT_ARGS = build_pipeline_init_args(
supports_binary_output=True, supports_binary_output=True,
) )
SUPPORTED_PEFT_TASKS = {
"document-question-answering": ["PeftModelForQuestionAnswering"],
"feature-extraction": ["PeftModelForFeatureExtraction", "PeftModel"],
"question-answering": ["PeftModelForQuestionAnswering"],
"summarization": ["PeftModelForSeq2SeqLM"],
"table-question-answering": ["PeftModelForQuestionAnswering"],
"text2text-generation": ["PeftModelForSeq2SeqLM"],
"text-classification": ["PeftModelForSequenceClassification"],
"sentiment-analysis": ["PeftModelForSequenceClassification"],
"text-generation": ["PeftModelForCausalLM"],
"token-classification": ["PeftModelForTokenClassification"],
"ner": ["PeftModelForTokenClassification"],
"translation": ["PeftModelForSeq2SeqLM"],
"translation_xx_to_yy": ["PeftModelForSeq2SeqLM"],
"zero-shot-classification": ["PeftModelForSequenceClassification"],
}
if is_torch_available(): if is_torch_available():
from transformers.pipelines.pt_utils import ( from transformers.pipelines.pt_utils import (
@@ -1209,6 +1225,9 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
""" """
if not isinstance(supported_models, list): # Create from a model mapping if not isinstance(supported_models, list): # Create from a model mapping
supported_models_names = [] supported_models_names = []
if self.task in SUPPORTED_PEFT_TASKS:
supported_models_names.extend(SUPPORTED_PEFT_TASKS[self.task])
for _, model_name in supported_models.items(): for _, model_name in supported_models.items():
# Mapping can now contain tuples of models for the same configuration. # Mapping can now contain tuples of models for the same configuration.
if isinstance(model_name, tuple): if isinstance(model_name, tuple):

View File

@@ -814,3 +814,50 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
msg = "When using prompt learning PEFT methods such as PREFIX_TUNING" msg = "When using prompt learning PEFT methods such as PREFIX_TUNING"
with self.assertRaisesRegex(RuntimeError, msg): with self.assertRaisesRegex(RuntimeError, msg):
trainer.train() trainer.train()
def test_peft_pipeline_no_warning(self):
"""
Test to verify that the warning message "The model 'PeftModel' is not supported for text-generation"
does not appear when using PeftModel with text-generation pipeline.
"""
from peft import PeftModel
from transformers import pipeline
ADAPTER_PATH = "peft-internal-testing/tiny-OPTForCausalLM-lora"
BASE_PATH = "hf-internal-testing/tiny-random-OPTForCausalLM"
# Input text for testing
text = "Who is a Elon Musk?"
expected_error_msg = "The model 'PeftModel' is not supported for text-generation"
model = AutoModelForCausalLM.from_pretrained(
BASE_PATH,
device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(BASE_PATH)
lora_model = PeftModel.from_pretrained(
model,
ADAPTER_PATH,
device_map="auto",
)
# Create pipeline with PEFT model while capturing log output
# Check that the warning message is not present in the logs
pipeline_logger = logging.get_logger("transformers.pipelines.base")
with self.assertNoLogs(pipeline_logger, logging.ERROR) as cl:
lora_generator = pipeline(
task="text-generation",
model=lora_model,
tokenizer=tokenizer,
max_length=10,
)
# Generate text to verify pipeline works
_ = lora_generator(text)
# Check that the warning message is not present in the logs
self.assertNotIn(
expected_error_msg, cl.out, f"Error message '{expected_error_msg}' should not appear when using PeftModel"
)