From 2aef9a96011133f6b399b598fd69cfeca936eb37 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 3 Oct 2023 14:53:09 +0200 Subject: [PATCH] [`PEFT`] Final fixes (#26559) * fix issues with PEFT * logger warning futurewarning issues * fixup * adapt from suggestions * oops * rm test --- src/transformers/integrations/peft.py | 9 +++-- src/transformers/modeling_utils.py | 18 ++++++---- .../peft_integration/test_peft_integration.py | 36 +++++++++++++++++++ tests/quantization/bnb/test_mixed_int8.py | 7 ---- 4 files changed, 54 insertions(+), 16 deletions(-) diff --git a/src/transformers/integrations/peft.py b/src/transformers/integrations/peft.py index aa4fc083df..de68e01c5f 100644 --- a/src/transformers/integrations/peft.py +++ b/src/transformers/integrations/peft.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect +import warnings from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from ..utils import ( @@ -159,6 +160,10 @@ class PeftAdapterMixin: "The one in `adapter_kwargs` will be used." ) + # Override token with adapter_kwargs' token + if "token" in adapter_kwargs: + token = adapter_kwargs.pop("token") + if peft_config is None: adapter_config_file = find_adapter_config_file( peft_model_id, @@ -381,8 +386,8 @@ class PeftAdapterMixin: return active_adapters def active_adapter(self) -> str: - logger.warning( - "The `active_adapter` method is deprecated and will be removed in a future version. ", FutureWarning + warnings.warn( + "The `active_adapter` method is deprecated and will be removed in a future version.", FutureWarning ) return self.active_adapters()[0] diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a548b019d8..5c3a121836 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1933,15 +1933,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if token is not None: kwargs["token"] = token + _hf_peft_config_loaded = getattr(self, "_hf_peft_config_loaded", False) + # Checks if the model has been loaded in 8-bit - if getattr(self, "is_loaded_in_8bit", False) and getattr(self, "is_8bit_serializable", False): - warnings.warn( + if ( + getattr(self, "is_loaded_in_8bit", False) + and not getattr(self, "is_8bit_serializable", False) + and not _hf_peft_config_loaded + ): + raise ValueError( "You are calling `save_pretrained` to a 8-bit converted model you may likely encounter unexepected" - " behaviors. If you want to save 8-bit models, make sure to have `bitsandbytes>0.37.2` installed.", - UserWarning, + " behaviors. If you want to save 8-bit models, make sure to have `bitsandbytes>0.37.2` installed." ) - if getattr(self, "is_loaded_in_4bit", False): + # If the model has adapters attached, you can save the adapters + if getattr(self, "is_loaded_in_4bit", False) and not _hf_peft_config_loaded: raise NotImplementedError( "You are calling `save_pretrained` on a 4-bit converted model. This is currently not supported" ) @@ -1982,8 +1988,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if self._auto_class is not None: custom_object_save(self, save_directory, config=self.config) - _hf_peft_config_loaded = getattr(model_to_save, "_hf_peft_config_loaded", False) - # Save the config if is_main_process: if not _hf_peft_config_loaded: diff --git a/tests/peft_integration/test_peft_integration.py b/tests/peft_integration/test_peft_integration.py index dbd0976dd4..809282c770 100644 --- a/tests/peft_integration/test_peft_integration.py +++ b/tests/peft_integration/test_peft_integration.py @@ -312,6 +312,42 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin): # dummy generation _ = peft_model.generate(input_ids=torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device)) + @require_torch_gpu + def test_peft_save_quantized(self): + """ + Simple test that tests the basic usage of PEFT model save_pretrained with quantized base models + """ + # 4bit + for model_id in self.peft_test_model_ids: + for transformers_class in self.transformers_test_model_classes: + peft_model = transformers_class.from_pretrained(model_id, load_in_4bit=True, device_map="auto") + + module = peft_model.model.decoder.layers[0].self_attn.v_proj + self.assertTrue(module.__class__.__name__ == "Linear4bit") + self.assertTrue(peft_model.hf_device_map is not None) + + with tempfile.TemporaryDirectory() as tmpdirname: + peft_model.save_pretrained(tmpdirname) + self.assertTrue("adapter_model.bin" in os.listdir(tmpdirname)) + self.assertTrue("adapter_config.json" in os.listdir(tmpdirname)) + self.assertTrue("pytorch_model.bin" not in os.listdir(tmpdirname)) + + # 8-bit + for model_id in self.peft_test_model_ids: + for transformers_class in self.transformers_test_model_classes: + peft_model = transformers_class.from_pretrained(model_id, load_in_8bit=True, device_map="auto") + + module = peft_model.model.decoder.layers[0].self_attn.v_proj + self.assertTrue(module.__class__.__name__ == "Linear8bitLt") + self.assertTrue(peft_model.hf_device_map is not None) + + with tempfile.TemporaryDirectory() as tmpdirname: + peft_model.save_pretrained(tmpdirname) + + self.assertTrue("adapter_model.bin" in os.listdir(tmpdirname)) + self.assertTrue("adapter_config.json" in os.listdir(tmpdirname)) + self.assertTrue("pytorch_model.bin" not in os.listdir(tmpdirname)) + def test_peft_pipeline(self): """ Simple test that tests the basic usage of PEFT model + pipeline diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index 4ff3a32b33..670be57d0c 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -263,13 +263,6 @@ class MixedInt8Test(BaseMixedInt8Test): self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) - def test_warns_save_pretrained(self): - r""" - Test whether trying to save a model after converting it in 8-bit will throw a warning. - """ - with self.assertWarns(UserWarning), tempfile.TemporaryDirectory() as tmpdirname: - self.model_8bit.save_pretrained(tmpdirname) - def test_raise_if_config_and_load_in_8bit(self): r""" Test that loading the model with the config and `load_in_8bit` raises an error