🔴 [VLM] Add base model without head (#37033)
* i guess reverted all CdGen classes * style * llava onevision * fix copies * fix some tests * some more tests * dump * skip these * nevermind, i am dumb * revert fix not needed * fixup * fixup * another fixup * more fixup to make ci finally happy * fixup after rebasing * fix qwen tests * add internVL + typos here and there * image token index -> id * style * fix init weights * revert blip-2 not supported * address comments * fix copies * revert blip2 test file as well * as discussed internally, revert back CdGen models * fix some tests * fix more tests for compile * CI red * fix copies * enumerate explicitly allowed models * address comments * fix tests * fixup * style again * add tests for new model class * another fixup ( x _ x ) * [fixup] unused attributes can be removed post-deprecation
This commit is contained in:
committed by
GitHub
parent
3fa8d9c20e
commit
17742bd9c8
@@ -216,6 +216,28 @@ TORCH_INIT_FUNCTIONS = {
|
||||
"kaiming_normal": nn.init.kaiming_normal,
|
||||
}
|
||||
|
||||
# DO NOT MODIFY, KEPT FOR BC ONLY
# Each entry is tested as a substring of the lower-cased model class name
# (`name in cls.__name__.lower()`); a match enables the class-level
# `_checkpoint_conversion_mapping` when loading and its reverse when saving.
VLMS = [
    "aria",
    "aya_vision",
    "emu3",
    "fuyu",
    "got_ocr2",
    "gemma3",
    "internvl",
    "llava",
    "llava_next",
    "llava_next_video",
    "llava_onevision",
    "mistral3",
    "mllama",
    "paligemma",
    "qwen2_vl",
    # Spelling corrected from "qwem2_5_vl": as a substring test against a
    # lower-cased class name, the misspelled entry could not match "qwen2_5_vl...".
    "qwen2_5_vl",
    "video_llava",
    "vipllava",
]
|
||||
|
||||
|
||||
@contextmanager
|
||||
def no_init_weights():
|
||||
@@ -1778,6 +1800,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
main_input_name = "input_ids"
|
||||
model_tags = None
|
||||
|
||||
_checkpoint_conversion_mapping = {} # used for BC support in VLMs, not meant to be used by new models
|
||||
|
||||
_auto_class = None
|
||||
_no_split_modules = None
|
||||
_skip_keys_device_placement = None
|
||||
@@ -3484,6 +3508,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
module_map[name + f".{key}"] = module
|
||||
state_dict = model_to_save.state_dict()
|
||||
|
||||
if any(allowed_name in self.__class__.__name__.lower() for allowed_name in VLMS):
|
||||
reverse_key_mapping = {v: k for k, v in self._checkpoint_conversion_mapping.items()}
|
||||
|
||||
original_state_dict = {}
|
||||
# Rename every key of the state dict using the reversed conversion mapping and
# collect the results in `original_state_dict`.
for key, value in state_dict.items():
    for pattern, replacement in reverse_key_mapping.items():
        replacement = replacement.lstrip("^")  # strip off un-needed chars and patterns
        # Drop parenthesised regex groups from the replacement text. This must
        # operate on `replacement`: running it on `pattern` threw away the
        # `lstrip` result above and made the substitution below replace the
        # match with the pattern's own text.
        replacement = re.sub(r"\(.*?\)", "", replacement)
        key, n_replace = re.subn(pattern, replacement, key)
        # Early exit of the loop: only the first mapping that matches is applied
        if n_replace > 0:
            break
    original_state_dict[key] = value
|
||||
state_dict = original_state_dict
|
||||
|
||||
# Translate state_dict from smp to hf if saving with smp >= 1.10
|
||||
if IS_SAGEMAKER_MP_POST_1_10:
|
||||
for smp_to_hf, _ in smp.state.module_manager.translate_functions:
|
||||
@@ -4071,7 +4110,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
gguf_file = kwargs.pop("gguf_file", None)
|
||||
tp_plan = kwargs.pop("tp_plan", None)
|
||||
tp_size = kwargs.pop("tp_size", None)
|
||||
key_mapping = kwargs.pop("key_mapping", None)
|
||||
|
||||
# Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
|
||||
if any(allowed_name in cls.__name__.lower() for allowed_name in VLMS):
|
||||
key_mapping = kwargs.pop("key_mapping", cls._checkpoint_conversion_mapping)
|
||||
else:
|
||||
key_mapping = kwargs.pop("key_mapping", None)
|
||||
|
||||
# Not used anymore -- remove them from the kwargs
|
||||
_ = kwargs.pop("resume_download", None)
|
||||
_ = kwargs.pop("trust_remote_code", None)
|
||||
|
||||
Reference in New Issue
Block a user