[PEFT] Support low_cpu_mem_usage option for PEFT loading adapters (#33725)
* [PEFT] Support low_cpu_mem_usage for PEFT loading: PEFT added support for low_cpu_mem_usage=True when loading adapters in https://github.com/huggingface/peft/pull/1961. This feature is available as of PEFT v0.13.0. With this PR, the option is also supported when loading PEFT adapters directly into transformers models. Additionally, this PR unblocks https://github.com/huggingface/diffusers/pull/9510, which implements the same option in diffusers. * Fix typo
This commit is contained in:
@@ -12,11 +12,13 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
from packaging import version
|
||||
|
||||
from transformers import AutoModelForCausalLM, OPTForCausalLM
|
||||
from transformers.testing_utils import (
|
||||
@@ -478,6 +480,48 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
||||
# dummy generation
|
||||
_ = model.generate(input_ids=dummy_input)
|
||||
|
||||
def test_peft_add_adapter_with_state_dict_low_cpu_mem_usage(self):
    """
    Check the usage of low_cpu_mem_usage, which is supported in PEFT >= 0.13.0

    For every (base model, adapter) pair and every tested model class, the adapter state dict is
    first loaded with `low_cpu_mem_usage=False`, which must always succeed. A second adapter named
    "other" is then loaded with `low_cpu_mem_usage=True`:

    - if the installed PEFT version is >= 0.13.0, loading must succeed and no model parameter may
      be left on the meta device;
    - otherwise, `load_adapter` must raise a `ValueError` saying the option is not supported yet.
    """
    # `import importlib` alone does not guarantee that the `importlib.metadata` submodule is
    # loaded; import it explicitly so the version lookup below does not depend on some other
    # module having imported it first.
    import importlib.metadata

    from peft import LoraConfig

    min_version_lcmu = "0.13.0"
    is_lcmu_supported = version.parse(importlib.metadata.version("peft")) >= version.parse(min_version_lcmu)

    for model_id, peft_model_id in zip(self.transformers_test_model_ids, self.peft_test_model_ids):
        for transformers_class in self.transformers_test_model_classes:
            model = transformers_class.from_pretrained(model_id).to(torch_device)

            peft_config = LoraConfig()
            state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
            # NOTE(review): `torch.load` unpickles a file downloaded from the Hub; consider
            # passing `weights_only=True` if the supported torch versions allow it.
            dummy_state_dict = torch.load(state_dict_path)

            # this should always work
            model.load_adapter(
                adapter_state_dict=dummy_state_dict, peft_config=peft_config, low_cpu_mem_usage=False
            )

            # Both branches below issue the same call; only the expected outcome differs.
            low_cpu_mem_kwargs = {
                "adapter_state_dict": dummy_state_dict,
                "adapter_name": "other",
                "peft_config": peft_config,
                "low_cpu_mem_usage": True,
            }

            if is_lcmu_supported:
                # if supported, this should not raise an error
                model.load_adapter(**low_cpu_mem_kwargs)
                # after loading, no meta device should be remaining
                self.assertFalse(any((p.device.type == "meta") for p in model.parameters()))
            else:
                err_msg = r"The version of PEFT you are using does not support `low_cpu_mem_usage` yet"
                with self.assertRaisesRegex(ValueError, err_msg):
                    model.load_adapter(**low_cpu_mem_kwargs)
|
||||
|
||||
def test_peft_from_pretrained_hub_kwargs(self):
|
||||
"""
|
||||
Tests different combinations of PEFT model + from_pretrained + hub kwargs
|
||||
|
||||
Reference in New Issue
Block a user