* add peft model in constant * add test * fix formating * make fixup execute * change code * check by self.task * add test * fixup test code * fix minor typo * fix pipeline test * apply maintainers reqests
This commit is contained in:
committed by
GitHub
parent
2527f71a47
commit
ad340908e4
@@ -849,6 +849,22 @@ PIPELINE_INIT_ARGS = build_pipeline_init_args(
|
|||||||
supports_binary_output=True,
|
supports_binary_output=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
SUPPORTED_PEFT_TASKS = {
|
||||||
|
"document-question-answering": ["PeftModelForQuestionAnswering"],
|
||||||
|
"feature-extraction": ["PeftModelForFeatureExtraction", "PeftModel"],
|
||||||
|
"question-answering": ["PeftModelForQuestionAnswering"],
|
||||||
|
"summarization": ["PeftModelForSeq2SeqLM"],
|
||||||
|
"table-question-answering": ["PeftModelForQuestionAnswering"],
|
||||||
|
"text2text-generation": ["PeftModelForSeq2SeqLM"],
|
||||||
|
"text-classification": ["PeftModelForSequenceClassification"],
|
||||||
|
"sentiment-analysis": ["PeftModelForSequenceClassification"],
|
||||||
|
"text-generation": ["PeftModelForCausalLM"],
|
||||||
|
"token-classification": ["PeftModelForTokenClassification"],
|
||||||
|
"ner": ["PeftModelForTokenClassification"],
|
||||||
|
"translation": ["PeftModelForSeq2SeqLM"],
|
||||||
|
"translation_xx_to_yy": ["PeftModelForSeq2SeqLM"],
|
||||||
|
"zero-shot-classification": ["PeftModelForSequenceClassification"],
|
||||||
|
}
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from transformers.pipelines.pt_utils import (
|
from transformers.pipelines.pt_utils import (
|
||||||
@@ -1209,6 +1225,9 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
|
|||||||
"""
|
"""
|
||||||
if not isinstance(supported_models, list): # Create from a model mapping
|
if not isinstance(supported_models, list): # Create from a model mapping
|
||||||
supported_models_names = []
|
supported_models_names = []
|
||||||
|
if self.task in SUPPORTED_PEFT_TASKS:
|
||||||
|
supported_models_names.extend(SUPPORTED_PEFT_TASKS[self.task])
|
||||||
|
|
||||||
for _, model_name in supported_models.items():
|
for _, model_name in supported_models.items():
|
||||||
# Mapping can now contain tuples of models for the same configuration.
|
# Mapping can now contain tuples of models for the same configuration.
|
||||||
if isinstance(model_name, tuple):
|
if isinstance(model_name, tuple):
|
||||||
|
|||||||
@@ -814,3 +814,50 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
|||||||
msg = "When using prompt learning PEFT methods such as PREFIX_TUNING"
|
msg = "When using prompt learning PEFT methods such as PREFIX_TUNING"
|
||||||
with self.assertRaisesRegex(RuntimeError, msg):
|
with self.assertRaisesRegex(RuntimeError, msg):
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
|
def test_peft_pipeline_no_warning(self):
|
||||||
|
"""
|
||||||
|
Test to verify that the warning message "The model 'PeftModel' is not supported for text-generation"
|
||||||
|
does not appear when using PeftModel with text-generation pipeline.
|
||||||
|
"""
|
||||||
|
from peft import PeftModel
|
||||||
|
|
||||||
|
from transformers import pipeline
|
||||||
|
|
||||||
|
ADAPTER_PATH = "peft-internal-testing/tiny-OPTForCausalLM-lora"
|
||||||
|
BASE_PATH = "hf-internal-testing/tiny-random-OPTForCausalLM"
|
||||||
|
|
||||||
|
# Input text for testing
|
||||||
|
text = "Who is a Elon Musk?"
|
||||||
|
expected_error_msg = "The model 'PeftModel' is not supported for text-generation"
|
||||||
|
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
BASE_PATH,
|
||||||
|
device_map="auto",
|
||||||
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(BASE_PATH)
|
||||||
|
|
||||||
|
lora_model = PeftModel.from_pretrained(
|
||||||
|
model,
|
||||||
|
ADAPTER_PATH,
|
||||||
|
device_map="auto",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create pipeline with PEFT model while capturing log output
|
||||||
|
# Check that the warning message is not present in the logs
|
||||||
|
pipeline_logger = logging.get_logger("transformers.pipelines.base")
|
||||||
|
with self.assertNoLogs(pipeline_logger, logging.ERROR) as cl:
|
||||||
|
lora_generator = pipeline(
|
||||||
|
task="text-generation",
|
||||||
|
model=lora_model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
max_length=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate text to verify pipeline works
|
||||||
|
_ = lora_generator(text)
|
||||||
|
|
||||||
|
# Check that the warning message is not present in the logs
|
||||||
|
self.assertNotIn(
|
||||||
|
expected_error_msg, cl.out, f"Error message '{expected_error_msg}' should not appear when using PeftModel"
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user