[PEFT] Support low_cpu_mem_usage option for PEFT loading adapters (#33725)

* [PEFT] Support low_cpu_mem_usage for PEFT loading

PEFT added support for low_cpu_mem_usage=True when loading adapters in
https://github.com/huggingface/peft/pull/1961. This feature is available
starting with PEFT v0.13.0. With this PR, the option is also supported
when loading PEFT adapters directly into transformers models.

Additionally, with this PR,
https://github.com/huggingface/diffusers/pull/9510 will be unblocked,
which implements this option in diffusers.

* Fix typo
This commit is contained in:
Benjamin Bossan
2024-10-03 16:15:36 +02:00
committed by GitHub
parent bf0ffe3d29
commit 6500f78c86
2 changed files with 67 additions and 2 deletions

View File

@@ -11,10 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import importlib
import inspect import inspect
import warnings import warnings
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from packaging import version
from ..utils import ( from ..utils import (
check_peft_version, check_peft_version,
find_adapter_config_file, find_adapter_config_file,
@@ -77,6 +80,7 @@ class PeftAdapterMixin:
offload_index: Optional[int] = None, offload_index: Optional[int] = None,
peft_config: Dict[str, Any] = None, peft_config: Dict[str, Any] = None,
adapter_state_dict: Optional[Dict[str, "torch.Tensor"]] = None, adapter_state_dict: Optional[Dict[str, "torch.Tensor"]] = None,
low_cpu_mem_usage: bool = False,
adapter_kwargs: Optional[Dict[str, Any]] = None, adapter_kwargs: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
""" """
@@ -129,12 +133,27 @@ class PeftAdapterMixin:
adapter_state_dict (`Dict[str, torch.Tensor]`, *optional*): adapter_state_dict (`Dict[str, torch.Tensor]`, *optional*):
The state dict of the adapter to load. This argument is used in case users directly pass PEFT state The state dict of the adapter to load. This argument is used in case users directly pass PEFT state
dicts dicts
low_cpu_mem_usage (`bool`, *optional*, defaults to `False`):
Reduce memory usage while loading the PEFT adapter. This should also speed up the loading process.
Requires PEFT version 0.13.0 or higher.
adapter_kwargs (`Dict[str, Any]`, *optional*): adapter_kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the `from_pretrained` method of the adapter config and Additional keyword arguments passed along to the `from_pretrained` method of the adapter config and
`find_adapter_config_file` method. `find_adapter_config_file` method.
""" """
check_peft_version(min_version=MIN_PEFT_VERSION) check_peft_version(min_version=MIN_PEFT_VERSION)
# peft only supports low_cpu_mem_usage starting from v0.13.0
peft_load_kwargs = {}
if low_cpu_mem_usage:
min_version_lcmu = "0.13.0"
if version.parse(importlib.metadata.version("peft")) >= version.parse(min_version_lcmu):
peft_load_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
else:
raise ValueError(
"The version of PEFT you are using does not support `low_cpu_mem_usage` yet, "
f"please install PEFT >= {min_version_lcmu}."
)
adapter_name = adapter_name if adapter_name is not None else "default" adapter_name = adapter_name if adapter_name is not None else "default"
if adapter_kwargs is None: if adapter_kwargs is None:
adapter_kwargs = {} adapter_kwargs = {}
@@ -192,7 +211,7 @@ class PeftAdapterMixin:
) )
# Create and add fresh new adapters into the model. # Create and add fresh new adapters into the model.
inject_adapter_in_model(peft_config, self, adapter_name) inject_adapter_in_model(peft_config, self, adapter_name, **peft_load_kwargs)
if not self._hf_peft_config_loaded: if not self._hf_peft_config_loaded:
self._hf_peft_config_loaded = True self._hf_peft_config_loaded = True
@@ -211,7 +230,9 @@ class PeftAdapterMixin:
processed_adapter_state_dict[new_key] = value processed_adapter_state_dict[new_key] = value
# Load state dict # Load state dict
incompatible_keys = set_peft_model_state_dict(self, processed_adapter_state_dict, adapter_name) incompatible_keys = set_peft_model_state_dict(
self, processed_adapter_state_dict, adapter_name, **peft_load_kwargs
)
if incompatible_keys is not None: if incompatible_keys is not None:
# check only for unexpected keys # check only for unexpected keys

View File

@@ -12,11 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import importlib
import os import os
import tempfile import tempfile
import unittest import unittest
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from packaging import version
from transformers import AutoModelForCausalLM, OPTForCausalLM from transformers import AutoModelForCausalLM, OPTForCausalLM
from transformers.testing_utils import ( from transformers.testing_utils import (
@@ -478,6 +480,48 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
# dummy generation # dummy generation
_ = model.generate(input_ids=dummy_input) _ = model.generate(input_ids=dummy_input)
def test_peft_add_adapter_with_state_dict_low_cpu_mem_usage(self):
    """
    Check the usage of low_cpu_mem_usage, which is supported in PEFT >= 0.13.0

    For every (model, adapter) pair and every model class under test:
    - loading the adapter with `low_cpu_mem_usage=False` must always succeed;
    - with `low_cpu_mem_usage=True`, loading must succeed and leave no parameter on the meta device
      when the installed PEFT is recent enough, and must raise a `ValueError` otherwise.
    """
    # `import importlib` alone does not guarantee that the `metadata` submodule is loaded; import it
    # explicitly so this test does not depend on another module having imported it first.
    import importlib.metadata

    from peft import LoraConfig

    min_version_lcmu = "0.13.0"
    is_lcmu_supported = version.parse(importlib.metadata.version("peft")) >= version.parse(min_version_lcmu)

    for model_id, peft_model_id in zip(self.transformers_test_model_ids, self.peft_test_model_ids):
        for transformers_class in self.transformers_test_model_classes:
            model = transformers_class.from_pretrained(model_id).to(torch_device)

            peft_config = LoraConfig()
            state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
            # NOTE(review): `torch.load` unpickles the downloaded file; consider `weights_only=True`
            # if the minimum supported torch version allows it.
            dummy_state_dict = torch.load(state_dict_path)

            # this should always work
            model.load_adapter(
                adapter_state_dict=dummy_state_dict, peft_config=peft_config, low_cpu_mem_usage=False
            )

            # the same call is exercised in both branches below; a second adapter name is used
            # because the default adapter was already loaded above
            lcmu_kwargs = {
                "adapter_state_dict": dummy_state_dict,
                "adapter_name": "other",
                "peft_config": peft_config,
                "low_cpu_mem_usage": True,
            }

            if is_lcmu_supported:
                # if supported, this should not raise an error
                model.load_adapter(**lcmu_kwargs)
                # after loading, no meta device should be remaining
                self.assertFalse(any(p.device.type == "meta" for p in model.parameters()))
            else:
                err_msg = r"The version of PEFT you are using does not support `low_cpu_mem_usage` yet"
                with self.assertRaisesRegex(ValueError, err_msg):
                    model.load_adapter(**lcmu_kwargs)
def test_peft_from_pretrained_hub_kwargs(self): def test_peft_from_pretrained_hub_kwargs(self):
""" """
Tests different combinations of PEFT model + from_pretrained + hub kwargs Tests different combinations of PEFT model + from_pretrained + hub kwargs