[PEFT] Introduce `adapter_kwargs` for loading adapters from a different Hub location (subfolder, revision) than the base model (#26270)

* make use of adapter_revision

* v1 adapter kwargs

* fix CI

* fix CI

* fix CI

* fixup

* add BC

* Update src/transformers/integrations/peft.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fixup

* change it to error

* Update src/transformers/modeling_utils.py

* Update src/transformers/modeling_utils.py

* fixup

* change

* Update src/transformers/integrations/peft.py

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Younes Belkada
2023-09-28 11:13:03 +02:00
committed by GitHub
parent 52e2c13da3
commit 38e96324ef
6 changed files with 68 additions and 9 deletions

View File

@@ -77,6 +77,7 @@ class PeftAdapterMixin:
offload_index: Optional[int] = None, offload_index: Optional[int] = None,
peft_config: Dict[str, Any] = None, peft_config: Dict[str, Any] = None,
adapter_state_dict: Optional[Dict[str, "torch.Tensor"]] = None, adapter_state_dict: Optional[Dict[str, "torch.Tensor"]] = None,
adapter_kwargs: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
""" """
Load adapter weights from file or remote Hub folder. If you are not familiar with adapters and PEFT methods, we Load adapter weights from file or remote Hub folder. If you are not familiar with adapters and PEFT methods, we
@@ -128,10 +129,15 @@ class PeftAdapterMixin:
adapter_state_dict (`Dict[str, torch.Tensor]`, *optional*): adapter_state_dict (`Dict[str, torch.Tensor]`, *optional*):
The state dict of the adapter to load. This argument is used in case users directly pass PEFT state The state dict of the adapter to load. This argument is used in case users directly pass PEFT state
dicts dicts
adapter_kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the `from_pretrained` method of the adapter config and
`find_adapter_config_file` method.
""" """
check_peft_version(min_version=MIN_PEFT_VERSION) check_peft_version(min_version=MIN_PEFT_VERSION)
adapter_name = adapter_name if adapter_name is not None else "default" adapter_name = adapter_name if adapter_name is not None else "default"
if adapter_kwargs is None:
adapter_kwargs = {}
from peft import PeftConfig, inject_adapter_in_model, load_peft_weights from peft import PeftConfig, inject_adapter_in_model, load_peft_weights
from peft.utils import set_peft_model_state_dict from peft.utils import set_peft_model_state_dict
@@ -144,11 +150,20 @@ class PeftAdapterMixin:
"You should either pass a `peft_model_id` or a `peft_config` and `adapter_state_dict` to load an adapter." "You should either pass a `peft_model_id` or a `peft_config` and `adapter_state_dict` to load an adapter."
) )
# We keep `revision` in the signature for backward compatibility
if revision is not None and "revision" not in adapter_kwargs:
adapter_kwargs["revision"] = revision
elif revision is not None and "revision" in adapter_kwargs and revision != adapter_kwargs["revision"]:
logger.error(
"You passed a `revision` argument both in `adapter_kwargs` and as a standalone argument. "
"The one in `adapter_kwargs` will be used."
)
if peft_config is None: if peft_config is None:
adapter_config_file = find_adapter_config_file( adapter_config_file = find_adapter_config_file(
peft_model_id, peft_model_id,
revision=revision,
token=token, token=token,
**adapter_kwargs,
) )
if adapter_config_file is None: if adapter_config_file is None:
@@ -159,8 +174,8 @@ class PeftAdapterMixin:
peft_config = PeftConfig.from_pretrained( peft_config = PeftConfig.from_pretrained(
peft_model_id, peft_model_id,
revision=revision,
use_auth_token=token, use_auth_token=token,
**adapter_kwargs,
) )
# Create and add fresh new adapters into the model. # Create and add fresh new adapters into the model.
@@ -170,7 +185,7 @@ class PeftAdapterMixin:
self._hf_peft_config_loaded = True self._hf_peft_config_loaded = True
if peft_model_id is not None: if peft_model_id is not None:
adapter_state_dict = load_peft_weights(peft_model_id, revision=revision, use_auth_token=token) adapter_state_dict = load_peft_weights(peft_model_id, use_auth_token=token, **adapter_kwargs)
# We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility # We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility
processed_adapter_state_dict = {} processed_adapter_state_dict = {}

View File

@@ -623,6 +623,9 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
subfolder = kwargs.pop("subfolder", "") subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None) commit_hash = kwargs.pop("_commit_hash", None)
# Not relevant for Flax Models
_ = kwargs.pop("adapter_kwargs", None)
if use_auth_token is not None: if use_auth_token is not None:
warnings.warn( warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning

View File

@@ -2645,6 +2645,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
commit_hash = kwargs.pop("_commit_hash", None) commit_hash = kwargs.pop("_commit_hash", None)
tf_to_pt_weight_rename = kwargs.pop("tf_to_pt_weight_rename", None) tf_to_pt_weight_rename = kwargs.pop("tf_to_pt_weight_rename", None)
# Not relevant for TF models
_ = kwargs.pop("adapter_kwargs", None)
if use_auth_token is not None: if use_auth_token is not None:
warnings.warn( warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning

View File

@@ -2463,7 +2463,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
subfolder = kwargs.pop("subfolder", "") subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None) commit_hash = kwargs.pop("_commit_hash", None)
variant = kwargs.pop("variant", None) variant = kwargs.pop("variant", None)
_adapter_model_path = kwargs.pop("_adapter_model_path", None) adapter_kwargs = kwargs.pop("adapter_kwargs", {})
adapter_name = kwargs.pop("adapter_name", "default") adapter_name = kwargs.pop("adapter_name", "default")
use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False) use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
@@ -2516,6 +2516,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
commit_hash = getattr(config, "_commit_hash", None) commit_hash = getattr(config, "_commit_hash", None)
if is_peft_available(): if is_peft_available():
_adapter_model_path = adapter_kwargs.pop("_adapter_model_path", None)
if _adapter_model_path is None: if _adapter_model_path is None:
_adapter_model_path = find_adapter_config_file( _adapter_model_path = find_adapter_config_file(
pretrained_model_name_or_path, pretrained_model_name_or_path,
@@ -2525,14 +2527,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
token=token, token=token,
revision=revision,
subfolder=subfolder,
_commit_hash=commit_hash, _commit_hash=commit_hash,
**adapter_kwargs,
) )
if _adapter_model_path is not None and os.path.isfile(_adapter_model_path): if _adapter_model_path is not None and os.path.isfile(_adapter_model_path):
with open(_adapter_model_path, "r", encoding="utf-8") as f: with open(_adapter_model_path, "r", encoding="utf-8") as f:
_adapter_model_path = pretrained_model_name_or_path _adapter_model_path = pretrained_model_name_or_path
pretrained_model_name_or_path = json.load(f)["base_model_name_or_path"] pretrained_model_name_or_path = json.load(f)["base_model_name_or_path"]
else:
_adapter_model_path = None
# change device_map into a map if we passed an int, a str or a torch.device # change device_map into a map if we passed an int, a str or a torch.device
if isinstance(device_map, torch.device): if isinstance(device_map, torch.device):
@@ -3371,8 +3374,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
model.load_adapter( model.load_adapter(
_adapter_model_path, _adapter_model_path,
adapter_name=adapter_name, adapter_name=adapter_name,
revision=revision,
token=token, token=token,
adapter_kwargs=adapter_kwargs,
) )
if output_loading_info: if output_loading_info:

View File

@@ -469,6 +469,7 @@ class _BaseAutoModelClass:
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs} hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
code_revision = kwargs.pop("code_revision", None) code_revision = kwargs.pop("code_revision", None)
commit_hash = kwargs.pop("_commit_hash", None) commit_hash = kwargs.pop("_commit_hash", None)
adapter_kwargs = kwargs.pop("adapter_kwargs", None)
revision = hub_kwargs.pop("revision", None) revision = hub_kwargs.pop("revision", None)
hub_kwargs["revision"] = sanitize_code_revision(pretrained_model_name_or_path, revision, trust_remote_code) hub_kwargs["revision"] = sanitize_code_revision(pretrained_model_name_or_path, revision, trust_remote_code)
@@ -503,15 +504,18 @@ class _BaseAutoModelClass:
commit_hash = getattr(config, "_commit_hash", None) commit_hash = getattr(config, "_commit_hash", None)
if is_peft_available(): if is_peft_available():
if adapter_kwargs is None:
adapter_kwargs = {}
maybe_adapter_path = find_adapter_config_file( maybe_adapter_path = find_adapter_config_file(
pretrained_model_name_or_path, _commit_hash=commit_hash, **hub_kwargs pretrained_model_name_or_path, _commit_hash=commit_hash, **adapter_kwargs
) )
if maybe_adapter_path is not None: if maybe_adapter_path is not None:
with open(maybe_adapter_path, "r", encoding="utf-8") as f: with open(maybe_adapter_path, "r", encoding="utf-8") as f:
adapter_config = json.load(f) adapter_config = json.load(f)
kwargs["_adapter_model_path"] = pretrained_model_name_or_path adapter_kwargs["_adapter_model_path"] = pretrained_model_name_or_path
pretrained_model_name_or_path = adapter_config["base_model_name_or_path"] pretrained_model_name_or_path = adapter_config["base_model_name_or_path"]
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
@@ -545,6 +549,10 @@ class _BaseAutoModelClass:
trust_remote_code = resolve_trust_remote_code( trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
) )
# Set the adapter kwargs
kwargs["adapter_kwargs"] = adapter_kwargs
if has_remote_code and trust_remote_code: if has_remote_code and trust_remote_code:
class_ref = config.auto_map[cls.__name__] class_ref = config.auto_map[cls.__name__]
model_class = get_class_from_dynamic_module( model_class = get_class_from_dynamic_module(

View File

@@ -351,3 +351,30 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
# dummy generation # dummy generation
_ = model.generate(input_ids=dummy_input) _ = model.generate(input_ids=dummy_input)
def test_peft_from_pretrained_hub_kwargs(self):
    """
    Check that `from_pretrained` forwards `adapter_kwargs` (revision / subfolder) when loading a PEFT
    checkpoint, for both the auto class and the concrete OPT class.
    """
    peft_model_id = "peft-internal-testing/tiny-opt-lora-revision"

    # Without any adapter kwargs the load is expected to fail with an OSError.
    # NOTE(review): presumably the default location of this repo holds no loadable weights - confirm on the Hub.
    with self.assertRaises(OSError):
        _ = AutoModelForCausalLM.from_pretrained(peft_model_id)

    # Each entry is one set of Hub kwargs that is expected to resolve to a valid adapter.
    working_adapter_kwargs = (
        {"revision": "test"},
        {"revision": "main", "subfolder": "test_subfolder"},
    )
    model_classes = (AutoModelForCausalLM, OPTForCausalLM)

    for adapter_kwargs in working_adapter_kwargs:
        # The same dict object is deliberately handed to both classes, one after the other.
        for model_class in model_classes:
            model = model_class.from_pretrained(peft_model_id, adapter_kwargs=adapter_kwargs)
            self.assertTrue(self._check_lora_correctly_converted(model))