🔴 [VLM] Add base model without head (#37033)

* i guessreverted all CdGen classes

* style

* llava onevision

* fix copies

* fix some tests

* some more tests

* dump

* skip these

* nevermind, i am dumb

* revert fix not needed

* fixup

* fixup

* another fixup

* more fixup to make ci finally happy

* fixup after rebasing

* fix qwen tests

* add internVL + typos here and there

* image token index -> id

* style

* fix init weights

* revert blip-2 not supported

* address comments

* fix copies

* revert blip2 test file as well

* as discussed internally, revert back CdGen models

* fix some tests

* fix more tests for compile

* CI red

* fix copies

* enumerate explicitly allowed models

* address comments

* fix tests

* fixup

* style again

* add tests for new model class

* another fixup ( x _ x )

* [fixup] unused attributes can be removed post-deprecation
This commit is contained in:
Raushan Turganbay
2025-05-07 17:47:51 +02:00
committed by GitHub
parent 3fa8d9c20e
commit 17742bd9c8
85 changed files with 7590 additions and 2904 deletions

View File

@@ -216,6 +216,28 @@ TORCH_INIT_FUNCTIONS = {
"kaiming_normal": nn.init.kaiming_normal,
}
# DO NOT MODIFY, KEPT FOR BC ONLY
VLMS = [
"aria",
"aya_vision",
"emu3",
"fuyu",
"got_ocr2",
"gemma3",
"internvl",
"llava",
"llava_next",
"llava_next_video",
"llava_onevision",
"mistral3",
"mllama",
"paligemma",
"qwen2_vl",
"qwem2_5_vl",
"video_llava",
"vipllava",
]
@contextmanager
def no_init_weights():
@@ -1778,6 +1800,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
main_input_name = "input_ids"
model_tags = None
_checkpoint_conversion_mapping = {} # used for BC support in VLMs, not meant to be used by new models
_auto_class = None
_no_split_modules = None
_skip_keys_device_placement = None
@@ -3484,6 +3508,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
module_map[name + f".{key}"] = module
state_dict = model_to_save.state_dict()
if any(allowed_name in self.__class__.__name__.lower() for allowed_name in VLMS):
reverse_key_mapping = {v: k for k, v in self._checkpoint_conversion_mapping.items()}
original_state_dict = {}
for key, value in state_dict.items():
for pattern, replacement in reverse_key_mapping.items():
replacement = replacement.lstrip("^") # strip off un-needed chars and patterns
replacement = re.sub(r"\(.*?\)", "", pattern)
key, n_replace = re.subn(pattern, replacement, key)
# Early exit of the loop
if n_replace > 0:
break
original_state_dict[key] = value
state_dict = original_state_dict
# Translate state_dict from smp to hf if saving with smp >= 1.10
if IS_SAGEMAKER_MP_POST_1_10:
for smp_to_hf, _ in smp.state.module_manager.translate_functions:
@@ -4071,7 +4110,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
gguf_file = kwargs.pop("gguf_file", None)
tp_plan = kwargs.pop("tp_plan", None)
tp_size = kwargs.pop("tp_size", None)
key_mapping = kwargs.pop("key_mapping", None)
# Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
if any(allowed_name in cls.__name__.lower() for allowed_name in VLMS):
key_mapping = kwargs.pop("key_mapping", cls._checkpoint_conversion_mapping)
else:
key_mapping = kwargs.pop("key_mapping", None)
# Not used anymore -- remove them from the kwargs
_ = kwargs.pop("resume_download", None)
_ = kwargs.pop("trust_remote_code", None)