From ad340908e441246f59462ee4f3450085569e4f8f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sangyun=5FLEE=20=28=EC=9D=B4=EC=83=81=EC=9C=A4=29?= Date: Wed, 9 Apr 2025 23:36:52 +0900 Subject: [PATCH] Fix warning message for PEFT models in text-generation pipeline #36783 (#36887) * add peft model in constant * add test * fix formating * make fixup execute * change code * check by self.task * add test * fixup test code * fix minor typo * fix pipeline test * apply maintainers reqests --- src/transformers/pipelines/base.py | 19 ++++++++ .../peft_integration/test_peft_integration.py | 47 +++++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 1df663a26e..459fd283ca 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -849,6 +849,22 @@ PIPELINE_INIT_ARGS = build_pipeline_init_args( supports_binary_output=True, ) +SUPPORTED_PEFT_TASKS = { + "document-question-answering": ["PeftModelForQuestionAnswering"], + "feature-extraction": ["PeftModelForFeatureExtraction", "PeftModel"], + "question-answering": ["PeftModelForQuestionAnswering"], + "summarization": ["PeftModelForSeq2SeqLM"], + "table-question-answering": ["PeftModelForQuestionAnswering"], + "text2text-generation": ["PeftModelForSeq2SeqLM"], + "text-classification": ["PeftModelForSequenceClassification"], + "sentiment-analysis": ["PeftModelForSequenceClassification"], + "text-generation": ["PeftModelForCausalLM"], + "token-classification": ["PeftModelForTokenClassification"], + "ner": ["PeftModelForTokenClassification"], + "translation": ["PeftModelForSeq2SeqLM"], + "translation_xx_to_yy": ["PeftModelForSeq2SeqLM"], + "zero-shot-classification": ["PeftModelForSequenceClassification"], +} if is_torch_available(): from transformers.pipelines.pt_utils import ( @@ -1209,6 +1225,9 @@ class Pipeline(_ScikitCompat, PushToHubMixin): """ if not isinstance(supported_models, list): # Create from a model mapping supported_models_names = [] + if self.task in SUPPORTED_PEFT_TASKS: + supported_models_names.extend(SUPPORTED_PEFT_TASKS[self.task]) + for _, model_name in supported_models.items(): # Mapping can now contain tuples of models for the same configuration. if isinstance(model_name, tuple): diff --git a/tests/peft_integration/test_peft_integration.py b/tests/peft_integration/test_peft_integration.py index 5bb16e4241..203124439d 100644 --- a/tests/peft_integration/test_peft_integration.py +++ b/tests/peft_integration/test_peft_integration.py @@ -814,3 +814,50 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin): msg = "When using prompt learning PEFT methods such as PREFIX_TUNING" with self.assertRaisesRegex(RuntimeError, msg): trainer.train() + + def test_peft_pipeline_no_warning(self): + """ + Test to verify that the warning message "The model 'PeftModel' is not supported for text-generation" + does not appear when using PeftModel with text-generation pipeline. + """ + from peft import PeftModel + + from transformers import pipeline + + ADAPTER_PATH = "peft-internal-testing/tiny-OPTForCausalLM-lora" + BASE_PATH = "hf-internal-testing/tiny-random-OPTForCausalLM" + + # Input text for testing + text = "Who is a Elon Musk?" + expected_error_msg = "The model 'PeftModel' is not supported for text-generation" + + model = AutoModelForCausalLM.from_pretrained( + BASE_PATH, + device_map="auto", + ) + tokenizer = AutoTokenizer.from_pretrained(BASE_PATH) + + lora_model = PeftModel.from_pretrained( + model, + ADAPTER_PATH, + device_map="auto", + ) + + # Create pipeline with PEFT model while capturing log output + # Check that the warning message is not present in the logs + pipeline_logger = logging.get_logger("transformers.pipelines.base") + with self.assertNoLogs(pipeline_logger, logging.ERROR) as cl: + lora_generator = pipeline( + task="text-generation", + model=lora_model, + tokenizer=tokenizer, + max_length=10, + ) + + # Generate text to verify pipeline works + _ = lora_generator(text) + + # Check that the warning message is not present in the logs + self.assertNotIn( + expected_error_msg, cl.out, f"Error message '{expected_error_msg}' should not appear when using PeftModel" + )