[PEFT] introducing adapter_kwargs for loading adapters from a different Hub location (subfolder, revision) than the base model (#26270)
* make use of adapter_revision
* v1 adapter kwargs
* fix CI
* fix CI
* fix CI
* fixup
* add BC
* Update src/transformers/integrations/peft.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fixup
* change it to error
* Update src/transformers/modeling_utils.py
* Update src/transformers/modeling_utils.py
* fixup
* change
* Update src/transformers/integrations/peft.py

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -77,6 +77,7 @@ class PeftAdapterMixin:
|
|||||||
offload_index: Optional[int] = None,
|
offload_index: Optional[int] = None,
|
||||||
peft_config: Dict[str, Any] = None,
|
peft_config: Dict[str, Any] = None,
|
||||||
adapter_state_dict: Optional[Dict[str, "torch.Tensor"]] = None,
|
adapter_state_dict: Optional[Dict[str, "torch.Tensor"]] = None,
|
||||||
|
adapter_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Load adapter weights from file or remote Hub folder. If you are not familiar with adapters and PEFT methods, we
|
Load adapter weights from file or remote Hub folder. If you are not familiar with adapters and PEFT methods, we
|
||||||
@@ -128,10 +129,15 @@ class PeftAdapterMixin:
|
|||||||
adapter_state_dict (`Dict[str, torch.Tensor]`, *optional*):
|
adapter_state_dict (`Dict[str, torch.Tensor]`, *optional*):
|
||||||
The state dict of the adapter to load. This argument is used in case users directly pass PEFT state
|
The state dict of the adapter to load. This argument is used in case users directly pass PEFT state
|
||||||
dicts
|
dicts
|
||||||
|
adapter_kwargs (`Dict[str, Any]`, *optional*):
|
||||||
|
Additional keyword arguments passed along to the `from_pretrained` method of the adapter config and
|
||||||
|
`find_adapter_config_file` method.
|
||||||
"""
|
"""
|
||||||
check_peft_version(min_version=MIN_PEFT_VERSION)
|
check_peft_version(min_version=MIN_PEFT_VERSION)
|
||||||
|
|
||||||
adapter_name = adapter_name if adapter_name is not None else "default"
|
adapter_name = adapter_name if adapter_name is not None else "default"
|
||||||
|
if adapter_kwargs is None:
|
||||||
|
adapter_kwargs = {}
|
||||||
|
|
||||||
from peft import PeftConfig, inject_adapter_in_model, load_peft_weights
|
from peft import PeftConfig, inject_adapter_in_model, load_peft_weights
|
||||||
from peft.utils import set_peft_model_state_dict
|
from peft.utils import set_peft_model_state_dict
|
||||||
@@ -144,11 +150,20 @@ class PeftAdapterMixin:
|
|||||||
"You should either pass a `peft_model_id` or a `peft_config` and `adapter_state_dict` to load an adapter."
|
"You should either pass a `peft_model_id` or a `peft_config` and `adapter_state_dict` to load an adapter."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# We keep `revision` in the signature for backward compatibility
|
||||||
|
if revision is not None and "revision" not in adapter_kwargs:
|
||||||
|
adapter_kwargs["revision"] = revision
|
||||||
|
elif revision is not None and "revision" in adapter_kwargs and revision != adapter_kwargs["revision"]:
|
||||||
|
logger.error(
|
||||||
|
"You passed a `revision` argument both in `adapter_kwargs` and as a standalone argument. "
|
||||||
|
"The one in `adapter_kwargs` will be used."
|
||||||
|
)
|
||||||
|
|
||||||
if peft_config is None:
|
if peft_config is None:
|
||||||
adapter_config_file = find_adapter_config_file(
|
adapter_config_file = find_adapter_config_file(
|
||||||
peft_model_id,
|
peft_model_id,
|
||||||
revision=revision,
|
|
||||||
token=token,
|
token=token,
|
||||||
|
**adapter_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if adapter_config_file is None:
|
if adapter_config_file is None:
|
||||||
@@ -159,8 +174,8 @@ class PeftAdapterMixin:
|
|||||||
|
|
||||||
peft_config = PeftConfig.from_pretrained(
|
peft_config = PeftConfig.from_pretrained(
|
||||||
peft_model_id,
|
peft_model_id,
|
||||||
revision=revision,
|
|
||||||
use_auth_token=token,
|
use_auth_token=token,
|
||||||
|
**adapter_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create and add fresh new adapters into the model.
|
# Create and add fresh new adapters into the model.
|
||||||
@@ -170,7 +185,7 @@ class PeftAdapterMixin:
|
|||||||
self._hf_peft_config_loaded = True
|
self._hf_peft_config_loaded = True
|
||||||
|
|
||||||
if peft_model_id is not None:
|
if peft_model_id is not None:
|
||||||
adapter_state_dict = load_peft_weights(peft_model_id, revision=revision, use_auth_token=token)
|
adapter_state_dict = load_peft_weights(peft_model_id, use_auth_token=token, **adapter_kwargs)
|
||||||
|
|
||||||
# We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility
|
# We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility
|
||||||
processed_adapter_state_dict = {}
|
processed_adapter_state_dict = {}
|
||||||
|
|||||||
@@ -623,6 +623,9 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
subfolder = kwargs.pop("subfolder", "")
|
subfolder = kwargs.pop("subfolder", "")
|
||||||
commit_hash = kwargs.pop("_commit_hash", None)
|
commit_hash = kwargs.pop("_commit_hash", None)
|
||||||
|
|
||||||
|
# Not relevant for Flax Models
|
||||||
|
_ = kwargs.pop("adapter_kwargs", None)
|
||||||
|
|
||||||
if use_auth_token is not None:
|
if use_auth_token is not None:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
|
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
|
||||||
|
|||||||
@@ -2645,6 +2645,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
commit_hash = kwargs.pop("_commit_hash", None)
|
commit_hash = kwargs.pop("_commit_hash", None)
|
||||||
tf_to_pt_weight_rename = kwargs.pop("tf_to_pt_weight_rename", None)
|
tf_to_pt_weight_rename = kwargs.pop("tf_to_pt_weight_rename", None)
|
||||||
|
|
||||||
|
# Not relevant for TF models
|
||||||
|
_ = kwargs.pop("adapter_kwargs", None)
|
||||||
|
|
||||||
if use_auth_token is not None:
|
if use_auth_token is not None:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
|
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
|
||||||
|
|||||||
@@ -2463,7 +2463,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
subfolder = kwargs.pop("subfolder", "")
|
subfolder = kwargs.pop("subfolder", "")
|
||||||
commit_hash = kwargs.pop("_commit_hash", None)
|
commit_hash = kwargs.pop("_commit_hash", None)
|
||||||
variant = kwargs.pop("variant", None)
|
variant = kwargs.pop("variant", None)
|
||||||
_adapter_model_path = kwargs.pop("_adapter_model_path", None)
|
adapter_kwargs = kwargs.pop("adapter_kwargs", {})
|
||||||
adapter_name = kwargs.pop("adapter_name", "default")
|
adapter_name = kwargs.pop("adapter_name", "default")
|
||||||
use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
|
use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
|
||||||
|
|
||||||
@@ -2516,6 +2516,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
commit_hash = getattr(config, "_commit_hash", None)
|
commit_hash = getattr(config, "_commit_hash", None)
|
||||||
|
|
||||||
if is_peft_available():
|
if is_peft_available():
|
||||||
|
_adapter_model_path = adapter_kwargs.pop("_adapter_model_path", None)
|
||||||
|
|
||||||
if _adapter_model_path is None:
|
if _adapter_model_path is None:
|
||||||
_adapter_model_path = find_adapter_config_file(
|
_adapter_model_path = find_adapter_config_file(
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
@@ -2525,14 +2527,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
proxies=proxies,
|
proxies=proxies,
|
||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
token=token,
|
token=token,
|
||||||
revision=revision,
|
|
||||||
subfolder=subfolder,
|
|
||||||
_commit_hash=commit_hash,
|
_commit_hash=commit_hash,
|
||||||
|
**adapter_kwargs,
|
||||||
)
|
)
|
||||||
if _adapter_model_path is not None and os.path.isfile(_adapter_model_path):
|
if _adapter_model_path is not None and os.path.isfile(_adapter_model_path):
|
||||||
with open(_adapter_model_path, "r", encoding="utf-8") as f:
|
with open(_adapter_model_path, "r", encoding="utf-8") as f:
|
||||||
_adapter_model_path = pretrained_model_name_or_path
|
_adapter_model_path = pretrained_model_name_or_path
|
||||||
pretrained_model_name_or_path = json.load(f)["base_model_name_or_path"]
|
pretrained_model_name_or_path = json.load(f)["base_model_name_or_path"]
|
||||||
|
else:
|
||||||
|
_adapter_model_path = None
|
||||||
|
|
||||||
# change device_map into a map if we passed an int, a str or a torch.device
|
# change device_map into a map if we passed an int, a str or a torch.device
|
||||||
if isinstance(device_map, torch.device):
|
if isinstance(device_map, torch.device):
|
||||||
@@ -3371,8 +3374,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
model.load_adapter(
|
model.load_adapter(
|
||||||
_adapter_model_path,
|
_adapter_model_path,
|
||||||
adapter_name=adapter_name,
|
adapter_name=adapter_name,
|
||||||
revision=revision,
|
|
||||||
token=token,
|
token=token,
|
||||||
|
adapter_kwargs=adapter_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if output_loading_info:
|
if output_loading_info:
|
||||||
|
|||||||
@@ -469,6 +469,7 @@ class _BaseAutoModelClass:
|
|||||||
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
|
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
|
||||||
code_revision = kwargs.pop("code_revision", None)
|
code_revision = kwargs.pop("code_revision", None)
|
||||||
commit_hash = kwargs.pop("_commit_hash", None)
|
commit_hash = kwargs.pop("_commit_hash", None)
|
||||||
|
adapter_kwargs = kwargs.pop("adapter_kwargs", None)
|
||||||
|
|
||||||
revision = hub_kwargs.pop("revision", None)
|
revision = hub_kwargs.pop("revision", None)
|
||||||
hub_kwargs["revision"] = sanitize_code_revision(pretrained_model_name_or_path, revision, trust_remote_code)
|
hub_kwargs["revision"] = sanitize_code_revision(pretrained_model_name_or_path, revision, trust_remote_code)
|
||||||
@@ -503,15 +504,18 @@ class _BaseAutoModelClass:
|
|||||||
commit_hash = getattr(config, "_commit_hash", None)
|
commit_hash = getattr(config, "_commit_hash", None)
|
||||||
|
|
||||||
if is_peft_available():
|
if is_peft_available():
|
||||||
|
if adapter_kwargs is None:
|
||||||
|
adapter_kwargs = {}
|
||||||
|
|
||||||
maybe_adapter_path = find_adapter_config_file(
|
maybe_adapter_path = find_adapter_config_file(
|
||||||
pretrained_model_name_or_path, _commit_hash=commit_hash, **hub_kwargs
|
pretrained_model_name_or_path, _commit_hash=commit_hash, **adapter_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
if maybe_adapter_path is not None:
|
if maybe_adapter_path is not None:
|
||||||
with open(maybe_adapter_path, "r", encoding="utf-8") as f:
|
with open(maybe_adapter_path, "r", encoding="utf-8") as f:
|
||||||
adapter_config = json.load(f)
|
adapter_config = json.load(f)
|
||||||
|
|
||||||
kwargs["_adapter_model_path"] = pretrained_model_name_or_path
|
adapter_kwargs["_adapter_model_path"] = pretrained_model_name_or_path
|
||||||
pretrained_model_name_or_path = adapter_config["base_model_name_or_path"]
|
pretrained_model_name_or_path = adapter_config["base_model_name_or_path"]
|
||||||
|
|
||||||
if not isinstance(config, PretrainedConfig):
|
if not isinstance(config, PretrainedConfig):
|
||||||
@@ -545,6 +549,10 @@ class _BaseAutoModelClass:
|
|||||||
trust_remote_code = resolve_trust_remote_code(
|
trust_remote_code = resolve_trust_remote_code(
|
||||||
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
|
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Set the adapter kwargs
|
||||||
|
kwargs["adapter_kwargs"] = adapter_kwargs
|
||||||
|
|
||||||
if has_remote_code and trust_remote_code:
|
if has_remote_code and trust_remote_code:
|
||||||
class_ref = config.auto_map[cls.__name__]
|
class_ref = config.auto_map[cls.__name__]
|
||||||
model_class = get_class_from_dynamic_module(
|
model_class = get_class_from_dynamic_module(
|
||||||
|
|||||||
@@ -351,3 +351,30 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
|||||||
|
|
||||||
# dummy generation
|
# dummy generation
|
||||||
_ = model.generate(input_ids=dummy_input)
|
_ = model.generate(input_ids=dummy_input)
|
||||||
|
|
||||||
|
def test_peft_from_pretrained_hub_kwargs(self):
|
||||||
|
"""
|
||||||
|
Tests different combinations of PEFT model + from_pretrained + hub kwargs
|
||||||
|
"""
|
||||||
|
peft_model_id = "peft-internal-testing/tiny-opt-lora-revision"
|
||||||
|
|
||||||
|
# This should not work
|
||||||
|
with self.assertRaises(OSError):
|
||||||
|
_ = AutoModelForCausalLM.from_pretrained(peft_model_id)
|
||||||
|
|
||||||
|
adapter_kwargs = {"revision": "test"}
|
||||||
|
|
||||||
|
# This should work
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(peft_model_id, adapter_kwargs=adapter_kwargs)
|
||||||
|
self.assertTrue(self._check_lora_correctly_converted(model))
|
||||||
|
|
||||||
|
model = OPTForCausalLM.from_pretrained(peft_model_id, adapter_kwargs=adapter_kwargs)
|
||||||
|
self.assertTrue(self._check_lora_correctly_converted(model))
|
||||||
|
|
||||||
|
adapter_kwargs = {"revision": "main", "subfolder": "test_subfolder"}
|
||||||
|
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(peft_model_id, adapter_kwargs=adapter_kwargs)
|
||||||
|
self.assertTrue(self._check_lora_correctly_converted(model))
|
||||||
|
|
||||||
|
model = OPTForCausalLM.from_pretrained(peft_model_id, adapter_kwargs=adapter_kwargs)
|
||||||
|
self.assertTrue(self._check_lora_correctly_converted(model))
|
||||||
|
|||||||
Reference in New Issue
Block a user