From d53518c5f2dd7ada022ff5b725c684c9ed89cb44 Mon Sep 17 00:00:00 2001 From: BUI Van Tuan <37981884+bvantuan@users.noreply.github.com> Date: Tue, 1 Jul 2025 09:47:53 +0200 Subject: [PATCH] Fix key mapping for VLMs (#39029) * fix key mapping for VLMs * use __mro__ instead * update key mapping in save_pretrained --- src/transformers/modeling_utils.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 515fb6d381..e99fb31ca3 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3746,7 +3746,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi module_map[name + f".{key}"] = module state_dict = model_to_save.state_dict() - if any(allowed_name in self.__class__.__name__.lower() for allowed_name in VLMS): + if any( + allowed_name in class_name.__name__.lower() + for class_name in self.__class__.__mro__[:-1] + for allowed_name in VLMS + ): reverse_key_mapping = {v: k for k, v in self._checkpoint_conversion_mapping.items()} original_state_dict = {} @@ -4402,7 +4406,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi key_mapping = kwargs.pop("key_mapping", None) # Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model - if key_mapping is None and any(allowed_name in cls.__name__.lower() for allowed_name in VLMS): + if key_mapping is None and any( + allowed_name in class_name.__name__.lower() for class_name in cls.__mro__[:-1] for allowed_name in VLMS + ): key_mapping = cls._checkpoint_conversion_mapping # Not used anymore -- remove them from the kwargs