* add peft model in constant * add test * fix formating * make fixup execute * change code * check by self.task * add test * fixup test code * fix minor typo * fix pipeline test * apply maintainers reqests
This commit is contained in:
committed by
GitHub
parent
2527f71a47
commit
ad340908e4
@@ -849,6 +849,22 @@ PIPELINE_INIT_ARGS = build_pipeline_init_args(
|
||||
supports_binary_output=True,
|
||||
)
|
||||
|
||||
SUPPORTED_PEFT_TASKS = {
|
||||
"document-question-answering": ["PeftModelForQuestionAnswering"],
|
||||
"feature-extraction": ["PeftModelForFeatureExtraction", "PeftModel"],
|
||||
"question-answering": ["PeftModelForQuestionAnswering"],
|
||||
"summarization": ["PeftModelForSeq2SeqLM"],
|
||||
"table-question-answering": ["PeftModelForQuestionAnswering"],
|
||||
"text2text-generation": ["PeftModelForSeq2SeqLM"],
|
||||
"text-classification": ["PeftModelForSequenceClassification"],
|
||||
"sentiment-analysis": ["PeftModelForSequenceClassification"],
|
||||
"text-generation": ["PeftModelForCausalLM"],
|
||||
"token-classification": ["PeftModelForTokenClassification"],
|
||||
"ner": ["PeftModelForTokenClassification"],
|
||||
"translation": ["PeftModelForSeq2SeqLM"],
|
||||
"translation_xx_to_yy": ["PeftModelForSeq2SeqLM"],
|
||||
"zero-shot-classification": ["PeftModelForSequenceClassification"],
|
||||
}
|
||||
|
||||
if is_torch_available():
|
||||
from transformers.pipelines.pt_utils import (
|
||||
@@ -1209,6 +1225,9 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
|
||||
"""
|
||||
if not isinstance(supported_models, list): # Create from a model mapping
|
||||
supported_models_names = []
|
||||
if self.task in SUPPORTED_PEFT_TASKS:
|
||||
supported_models_names.extend(SUPPORTED_PEFT_TASKS[self.task])
|
||||
|
||||
for _, model_name in supported_models.items():
|
||||
# Mapping can now contain tuples of models for the same configuration.
|
||||
if isinstance(model_name, tuple):
|
||||
|
||||
@@ -814,3 +814,50 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
||||
msg = "When using prompt learning PEFT methods such as PREFIX_TUNING"
|
||||
with self.assertRaisesRegex(RuntimeError, msg):
|
||||
trainer.train()
|
||||
|
||||
def test_peft_pipeline_no_warning(self):
|
||||
"""
|
||||
Test to verify that the warning message "The model 'PeftModel' is not supported for text-generation"
|
||||
does not appear when using PeftModel with text-generation pipeline.
|
||||
"""
|
||||
from peft import PeftModel
|
||||
|
||||
from transformers import pipeline
|
||||
|
||||
ADAPTER_PATH = "peft-internal-testing/tiny-OPTForCausalLM-lora"
|
||||
BASE_PATH = "hf-internal-testing/tiny-random-OPTForCausalLM"
|
||||
|
||||
# Input text for testing
|
||||
text = "Who is a Elon Musk?"
|
||||
expected_error_msg = "The model 'PeftModel' is not supported for text-generation"
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
BASE_PATH,
|
||||
device_map="auto",
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(BASE_PATH)
|
||||
|
||||
lora_model = PeftModel.from_pretrained(
|
||||
model,
|
||||
ADAPTER_PATH,
|
||||
device_map="auto",
|
||||
)
|
||||
|
||||
# Create pipeline with PEFT model while capturing log output
|
||||
# Check that the warning message is not present in the logs
|
||||
pipeline_logger = logging.get_logger("transformers.pipelines.base")
|
||||
with self.assertNoLogs(pipeline_logger, logging.ERROR) as cl:
|
||||
lora_generator = pipeline(
|
||||
task="text-generation",
|
||||
model=lora_model,
|
||||
tokenizer=tokenizer,
|
||||
max_length=10,
|
||||
)
|
||||
|
||||
# Generate text to verify pipeline works
|
||||
_ = lora_generator(text)
|
||||
|
||||
# Check that the warning message is not present in the logs
|
||||
self.assertNotIn(
|
||||
expected_error_msg, cl.out, f"Error message '{expected_error_msg}' should not appear when using PeftModel"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user