From 2aff938992b756a6670f196e589a9ae6aa446b26 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 3 Mar 2025 18:01:43 +0000 Subject: [PATCH] Fix pipeline+peft interaction (#36480) * Fix pipeline-peft interaction * once again you have committed a debug breakpoint * Remove extra testing line * Add a test to check adapter loading * Correct adapter path * make fixup * Remove unnecessary check * Make check a little more stringent --- src/transformers/pipelines/__init__.py | 4 +++- tests/peft_integration/test_peft_integration.py | 10 +++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index 257f5689b0..e57e1fac51 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -824,6 +824,7 @@ def pipeline( # Config is the primordial information item. # Instantiate config if needed + adapter_path = None if isinstance(config, str): config = AutoConfig.from_pretrained( config, _from_pipeline=task, code_revision=code_revision, **hub_kwargs, **model_kwargs @@ -844,6 +845,7 @@ def pipeline( if maybe_adapter_path is not None: with open(maybe_adapter_path, "r", encoding="utf-8") as f: adapter_config = json.load(f) + adapter_path = model model = adapter_config["base_model_name_or_path"] config = AutoConfig.from_pretrained( @@ -938,7 +940,7 @@ def pipeline( if isinstance(model, str) or framework is None: model_classes = {"tf": targeted_task["tf"], "pt": targeted_task["pt"]} framework, model = infer_framework_load_model( - model, + adapter_path if adapter_path is not None else model, model_classes=model_classes, config=config, framework=framework, diff --git a/tests/peft_integration/test_peft_integration.py b/tests/peft_integration/test_peft_integration.py index f48584d612..5ebb53c5be 100644 --- a/tests/peft_integration/test_peft_integration.py +++ b/tests/peft_integration/test_peft_integration.py @@ -526,9 +526,13 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin): """ from transformers import pipeline - for model_id in self.peft_test_model_ids: - pipe = pipeline("text-generation", model_id) - _ = pipe("Hello") + for adapter_id, base_model_id in zip(self.peft_test_model_ids, self.transformers_test_model_ids): + peft_pipe = pipeline("text-generation", adapter_id) + base_pipe = pipeline("text-generation", base_model_id) + peft_params = list(peft_pipe.model.parameters()) + base_params = list(base_pipe.model.parameters()) + self.assertNotEqual(len(peft_params), len(base_params)) # Assert we actually loaded the adapter too + _ = peft_pipe("Hello") def test_peft_add_adapter_with_state_dict(self): """