Fix pipeline+peft interaction (#36480)
* Fix pipeline-peft interaction * once again you have committed a debug breakpoint * Remove extra testing line * Add a test to check adapter loading * Correct adapter path * make fixup * Remove unnecessary check * Make check a little more stringent
This commit is contained in:
@@ -824,6 +824,7 @@ def pipeline(
|
|||||||
|
|
||||||
# Config is the primordial information item.
|
# Config is the primordial information item.
|
||||||
# Instantiate config if needed
|
# Instantiate config if needed
|
||||||
|
adapter_path = None
|
||||||
if isinstance(config, str):
|
if isinstance(config, str):
|
||||||
config = AutoConfig.from_pretrained(
|
config = AutoConfig.from_pretrained(
|
||||||
config, _from_pipeline=task, code_revision=code_revision, **hub_kwargs, **model_kwargs
|
config, _from_pipeline=task, code_revision=code_revision, **hub_kwargs, **model_kwargs
|
||||||
@@ -844,6 +845,7 @@ def pipeline(
|
|||||||
if maybe_adapter_path is not None:
|
if maybe_adapter_path is not None:
|
||||||
with open(maybe_adapter_path, "r", encoding="utf-8") as f:
|
with open(maybe_adapter_path, "r", encoding="utf-8") as f:
|
||||||
adapter_config = json.load(f)
|
adapter_config = json.load(f)
|
||||||
|
adapter_path = model
|
||||||
model = adapter_config["base_model_name_or_path"]
|
model = adapter_config["base_model_name_or_path"]
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(
|
config = AutoConfig.from_pretrained(
|
||||||
@@ -938,7 +940,7 @@ def pipeline(
|
|||||||
if isinstance(model, str) or framework is None:
|
if isinstance(model, str) or framework is None:
|
||||||
model_classes = {"tf": targeted_task["tf"], "pt": targeted_task["pt"]}
|
model_classes = {"tf": targeted_task["tf"], "pt": targeted_task["pt"]}
|
||||||
framework, model = infer_framework_load_model(
|
framework, model = infer_framework_load_model(
|
||||||
model,
|
adapter_path if adapter_path is not None else model,
|
||||||
model_classes=model_classes,
|
model_classes=model_classes,
|
||||||
config=config,
|
config=config,
|
||||||
framework=framework,
|
framework=framework,
|
||||||
|
|||||||
@@ -526,9 +526,13 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
|||||||
"""
|
"""
|
||||||
from transformers import pipeline
|
from transformers import pipeline
|
||||||
|
|
||||||
for model_id in self.peft_test_model_ids:
|
for adapter_id, base_model_id in zip(self.peft_test_model_ids, self.transformers_test_model_ids):
|
||||||
pipe = pipeline("text-generation", model_id)
|
peft_pipe = pipeline("text-generation", adapter_id)
|
||||||
_ = pipe("Hello")
|
base_pipe = pipeline("text-generation", base_model_id)
|
||||||
|
peft_params = list(peft_pipe.model.parameters())
|
||||||
|
base_params = list(base_pipe.model.parameters())
|
||||||
|
self.assertNotEqual(len(peft_params), len(base_params)) # Assert we actually loaded the adapter too
|
||||||
|
_ = peft_pipe("Hello")
|
||||||
|
|
||||||
def test_peft_add_adapter_with_state_dict(self):
|
def test_peft_add_adapter_with_state_dict(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user