[PEFT] Final fixes (#26559)
* fix issues with PEFT * logger warning futurewarning issues * fixup * adapt from suggestions * oops * rm test
This commit is contained in:
@@ -12,6 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import inspect
|
import inspect
|
||||||
|
import warnings
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
@@ -159,6 +160,10 @@ class PeftAdapterMixin:
|
|||||||
"The one in `adapter_kwargs` will be used."
|
"The one in `adapter_kwargs` will be used."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Override token with adapter_kwargs' token
|
||||||
|
if "token" in adapter_kwargs:
|
||||||
|
token = adapter_kwargs.pop("token")
|
||||||
|
|
||||||
if peft_config is None:
|
if peft_config is None:
|
||||||
adapter_config_file = find_adapter_config_file(
|
adapter_config_file = find_adapter_config_file(
|
||||||
peft_model_id,
|
peft_model_id,
|
||||||
@@ -381,7 +386,7 @@ class PeftAdapterMixin:
|
|||||||
return active_adapters
|
return active_adapters
|
||||||
|
|
||||||
def active_adapter(self) -> str:
|
def active_adapter(self) -> str:
|
||||||
logger.warning(
|
warnings.warn(
|
||||||
"The `active_adapter` method is deprecated and will be removed in a future version.", FutureWarning
|
"The `active_adapter` method is deprecated and will be removed in a future version.", FutureWarning
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1933,15 +1933,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
if token is not None:
|
if token is not None:
|
||||||
kwargs["token"] = token
|
kwargs["token"] = token
|
||||||
|
|
||||||
|
_hf_peft_config_loaded = getattr(self, "_hf_peft_config_loaded", False)
|
||||||
|
|
||||||
# Checks if the model has been loaded in 8-bit
|
# Checks if the model has been loaded in 8-bit
|
||||||
if getattr(self, "is_loaded_in_8bit", False) and getattr(self, "is_8bit_serializable", False):
|
if (
|
||||||
warnings.warn(
|
getattr(self, "is_loaded_in_8bit", False)
|
||||||
|
and not getattr(self, "is_8bit_serializable", False)
|
||||||
|
and not _hf_peft_config_loaded
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
"You are calling `save_pretrained` to a 8-bit converted model you may likely encounter unexepected"
|
"You are calling `save_pretrained` to a 8-bit converted model you may likely encounter unexepected"
|
||||||
" behaviors. If you want to save 8-bit models, make sure to have `bitsandbytes>0.37.2` installed.",
|
" behaviors. If you want to save 8-bit models, make sure to have `bitsandbytes>0.37.2` installed."
|
||||||
UserWarning,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if getattr(self, "is_loaded_in_4bit", False):
|
# If the model has adapters attached, you can save the adapters
|
||||||
|
if getattr(self, "is_loaded_in_4bit", False) and not _hf_peft_config_loaded:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"You are calling `save_pretrained` on a 4-bit converted model. This is currently not supported"
|
"You are calling `save_pretrained` on a 4-bit converted model. This is currently not supported"
|
||||||
)
|
)
|
||||||
@@ -1982,8 +1988,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
if self._auto_class is not None:
|
if self._auto_class is not None:
|
||||||
custom_object_save(self, save_directory, config=self.config)
|
custom_object_save(self, save_directory, config=self.config)
|
||||||
|
|
||||||
_hf_peft_config_loaded = getattr(model_to_save, "_hf_peft_config_loaded", False)
|
|
||||||
|
|
||||||
# Save the config
|
# Save the config
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
if not _hf_peft_config_loaded:
|
if not _hf_peft_config_loaded:
|
||||||
|
|||||||
@@ -312,6 +312,42 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
|||||||
# dummy generation
|
# dummy generation
|
||||||
_ = peft_model.generate(input_ids=torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device))
|
_ = peft_model.generate(input_ids=torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device))
|
||||||
|
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_peft_save_quantized(self):
|
||||||
|
"""
|
||||||
|
Simple test that tests the basic usage of PEFT model save_pretrained with quantized base models
|
||||||
|
"""
|
||||||
|
# 4bit
|
||||||
|
for model_id in self.peft_test_model_ids:
|
||||||
|
for transformers_class in self.transformers_test_model_classes:
|
||||||
|
peft_model = transformers_class.from_pretrained(model_id, load_in_4bit=True, device_map="auto")
|
||||||
|
|
||||||
|
module = peft_model.model.decoder.layers[0].self_attn.v_proj
|
||||||
|
self.assertTrue(module.__class__.__name__ == "Linear4bit")
|
||||||
|
self.assertTrue(peft_model.hf_device_map is not None)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
peft_model.save_pretrained(tmpdirname)
|
||||||
|
self.assertTrue("adapter_model.bin" in os.listdir(tmpdirname))
|
||||||
|
self.assertTrue("adapter_config.json" in os.listdir(tmpdirname))
|
||||||
|
self.assertTrue("pytorch_model.bin" not in os.listdir(tmpdirname))
|
||||||
|
|
||||||
|
# 8-bit
|
||||||
|
for model_id in self.peft_test_model_ids:
|
||||||
|
for transformers_class in self.transformers_test_model_classes:
|
||||||
|
peft_model = transformers_class.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
|
||||||
|
|
||||||
|
module = peft_model.model.decoder.layers[0].self_attn.v_proj
|
||||||
|
self.assertTrue(module.__class__.__name__ == "Linear8bitLt")
|
||||||
|
self.assertTrue(peft_model.hf_device_map is not None)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
peft_model.save_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
self.assertTrue("adapter_model.bin" in os.listdir(tmpdirname))
|
||||||
|
self.assertTrue("adapter_config.json" in os.listdir(tmpdirname))
|
||||||
|
self.assertTrue("pytorch_model.bin" not in os.listdir(tmpdirname))
|
||||||
|
|
||||||
def test_peft_pipeline(self):
|
def test_peft_pipeline(self):
|
||||||
"""
|
"""
|
||||||
Simple test that tests the basic usage of PEFT model + pipeline
|
Simple test that tests the basic usage of PEFT model + pipeline
|
||||||
|
|||||||
@@ -263,13 +263,6 @@ class MixedInt8Test(BaseMixedInt8Test):
|
|||||||
|
|
||||||
self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||||
|
|
||||||
def test_warns_save_pretrained(self):
|
|
||||||
r"""
|
|
||||||
Test whether trying to save a model after converting it in 8-bit will throw a warning.
|
|
||||||
"""
|
|
||||||
with self.assertWarns(UserWarning), tempfile.TemporaryDirectory() as tmpdirname:
|
|
||||||
self.model_8bit.save_pretrained(tmpdirname)
|
|
||||||
|
|
||||||
def test_raise_if_config_and_load_in_8bit(self):
|
def test_raise_if_config_and_load_in_8bit(self):
|
||||||
r"""
|
r"""
|
||||||
Test that loading the model with the config and `load_in_8bit` raises an error
|
Test that loading the model with the config and `load_in_8bit` raises an error
|
||||||
|
|||||||
Reference in New Issue
Block a user