Add conversion for interleave llava (#31858)
* add conversion for interleave llava * remove debug lines * remove unused imports * Update src/transformers/models/llava/convert_llava_weights_to_hf.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * small changes + docs --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
ad35309a62
commit
97aa3e2905
@@ -40,8 +40,20 @@ The original code can be found [here](https://github.com/haotian-liu/LLaVA/tree/
|
|||||||
|
|
||||||
- Note the model has not been explicitly trained to process multiple images in the same prompt, although this is technically possible, you may experience inaccurate results.
|
- Note the model has not been explicitly trained to process multiple images in the same prompt, although this is technically possible, you may experience inaccurate results.
|
||||||
|
|
||||||
- For better results, we recommend users to prompt the model with the correct prompt format:
|
- For better results, we recommend users to prompt the model with the correct prompt format. Below is a list of prompt formats accepted by each llava checkpoint:
|
||||||
|
|
||||||
|
[llava-interleave models](https://huggingface.co/collections/llava-hf/llava-interleave-668e19a97da0036aad4a2f19) require the following format:
|
||||||
|
```bash
|
||||||
|
"<|im_start|>user <image>\nWhat is shown in this image?<|im_end|><|im_start|>assistant"
|
||||||
|
```
|
||||||
|
|
||||||
|
For multi-turn conversations:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
"<|im_start|>user <image>\n<prompt1><|im_end|><|im_start|>assistant <answer1><|im_end|><|im_start|>user <image>\n<prompt1><|im_end|><|im_start|>assistant "
|
||||||
|
```
|
||||||
|
|
||||||
|
[llava-1.5 models](https://huggingface.co/collections/llava-hf/llava-15-65f762d5b6941db5c2ba07e0) require the following format:
|
||||||
```bash
|
```bash
|
||||||
"USER: <image>\n<prompt> ASSISTANT:"
|
"USER: <image>\n<prompt> ASSISTANT:"
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -12,18 +12,21 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import argparse
|
import argparse
|
||||||
|
import glob
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download, snapshot_download
|
||||||
|
from safetensors import safe_open
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AddedToken,
|
AddedToken,
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
|
AutoImageProcessor,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
CLIPImageProcessor,
|
|
||||||
LlavaConfig,
|
LlavaConfig,
|
||||||
LlavaForConditionalGeneration,
|
LlavaForConditionalGeneration,
|
||||||
LlavaProcessor,
|
LlavaProcessor,
|
||||||
|
SiglipVisionConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -48,6 +51,7 @@ Example for creating the old state dict file with Python:
|
|||||||
|
|
||||||
KEYS_TO_MODIFY_MAPPING = {
|
KEYS_TO_MODIFY_MAPPING = {
|
||||||
"model.vision_tower.": "",
|
"model.vision_tower.": "",
|
||||||
|
".vision_resampler": "", # all lmms-lab models do avg pooling, so no vision_resampler
|
||||||
"model.mm_projector": "multi_modal_projector",
|
"model.mm_projector": "multi_modal_projector",
|
||||||
"model": "model.model",
|
"model": "model.model",
|
||||||
"vision_model.model": "vision_model",
|
"vision_model.model": "vision_model",
|
||||||
@@ -58,6 +62,26 @@ KEYS_TO_MODIFY_MAPPING = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def load_original_state_dict(model_id):
|
||||||
|
directory_path = snapshot_download(repo_id=model_id, allow_patterns=["*.safetensors"])
|
||||||
|
|
||||||
|
original_state_dict = {}
|
||||||
|
for path in glob.glob(f"{directory_path}/*"):
|
||||||
|
if path.endswith(".safetensors"):
|
||||||
|
with safe_open(path, framework="pt", device="cpu") as f:
|
||||||
|
for key in f.keys():
|
||||||
|
original_state_dict[key] = f.get_tensor(key)
|
||||||
|
|
||||||
|
# tied weights, so lm_head is not saved. Let's clone the input embeddings so the state dict can be loaded
|
||||||
|
if "lm_head.weight" not in original_state_dict:
|
||||||
|
original_state_dict["lm_head.weight"] = original_state_dict["model.embed_tokens.weight"].clone()
|
||||||
|
|
||||||
|
del original_state_dict["model.image_newline"] # not used in the original implementation because "merge_type=flat"
|
||||||
|
return original_state_dict
|
||||||
|
|
||||||
|
|
||||||
|
# used only for llava-interleave
|
||||||
|
# for ex: Qwen/Qwen1.5-0.5B-Chat google/siglip-so400m-patch14-384 lmms-lab/llava-next-interleave-qwen-0.5b
|
||||||
def convert_state_dict_to_hf(state_dict):
|
def convert_state_dict_to_hf(state_dict):
|
||||||
new_state_dict = {}
|
new_state_dict = {}
|
||||||
for key, value in state_dict.items():
|
for key, value in state_dict.items():
|
||||||
@@ -77,24 +101,48 @@ def convert_llava_llama_to_hf(text_model_id, vision_model_id, output_hub_path, o
|
|||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(text_model_id)
|
tokenizer = AutoTokenizer.from_pretrained(text_model_id)
|
||||||
tokenizer.add_tokens(AddedToken("<image>", special=True, normalized=False), special_tokens=True)
|
tokenizer.add_tokens(AddedToken("<image>", special=True, normalized=False), special_tokens=True)
|
||||||
|
if "Qwen" not in text_model_id: # qwen already has a pad token
|
||||||
tokenizer.add_special_tokens({"pad_token": "<pad>"})
|
tokenizer.add_special_tokens({"pad_token": "<pad>"})
|
||||||
|
|
||||||
image_processor = CLIPImageProcessor.from_pretrained(vision_model_id)
|
image_processor = AutoImageProcessor.from_pretrained(vision_model_id)
|
||||||
|
|
||||||
processor = LlavaProcessor(tokenizer=tokenizer, image_processor=image_processor)
|
processor = LlavaProcessor(tokenizer=tokenizer, image_processor=image_processor)
|
||||||
|
|
||||||
config = LlavaConfig(text_config=text_config)
|
if "Qwen" in text_model_id:
|
||||||
|
vision_config = SiglipVisionConfig(
|
||||||
|
hidden_size=1152,
|
||||||
|
image_size=384,
|
||||||
|
intermediate_size=4304,
|
||||||
|
num_attention_heads=16,
|
||||||
|
num_hidden_layers=26,
|
||||||
|
patch_size=14,
|
||||||
|
vision_use_head=False,
|
||||||
|
).to_dict()
|
||||||
|
else:
|
||||||
|
vision_config = None
|
||||||
|
|
||||||
|
config = LlavaConfig(
|
||||||
|
text_config=text_config,
|
||||||
|
vision_config=vision_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# lmms-lab interleave models do not use any selection strategy except for the last hidden state
|
||||||
|
if "Qwen" in text_model_id:
|
||||||
|
config.image_token_index = 151646
|
||||||
|
config.vision_feature_select_strategy = "full"
|
||||||
|
config.vision_feature_layer = -1
|
||||||
|
else:
|
||||||
config.pad_token_id = 32001
|
config.pad_token_id = 32001
|
||||||
|
config.image_token_index = 32000
|
||||||
|
|
||||||
with torch.device("meta"):
|
with torch.device("meta"):
|
||||||
model = LlavaForConditionalGeneration(config)
|
model = LlavaForConditionalGeneration(config)
|
||||||
|
|
||||||
# Pad to 64 for performance reasons
|
if "Qwen" in text_model_id:
|
||||||
pad_shape = 64
|
state_dict = load_original_state_dict(old_state_dict_id)
|
||||||
|
else:
|
||||||
state_dict_path = hf_hub_download(old_state_dict_id, "model_state_dict.bin")
|
state_dict_path = hf_hub_download(old_state_dict_id, "model_state_dict.bin")
|
||||||
|
|
||||||
state_dict = torch.load(state_dict_path, map_location="cpu")
|
state_dict = torch.load(state_dict_path, map_location="cpu")
|
||||||
|
|
||||||
state_dict = convert_state_dict_to_hf(state_dict)
|
state_dict = convert_state_dict_to_hf(state_dict)
|
||||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||||
|
|
||||||
@@ -104,14 +152,18 @@ def convert_llava_llama_to_hf(text_model_id, vision_model_id, output_hub_path, o
|
|||||||
sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n
|
sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n
|
||||||
dist = torch.distributions.multivariate_normal.MultivariateNormal(mu, covariance_matrix=1e-5 * sigma)
|
dist = torch.distributions.multivariate_normal.MultivariateNormal(mu, covariance_matrix=1e-5 * sigma)
|
||||||
|
|
||||||
# We add an image token so we resize the model
|
# We add an image token so we resize the model and pad to 64 for performance reasons
|
||||||
|
pad_shape = 64
|
||||||
|
vocab_size = config.text_config.vocab_size
|
||||||
model.resize_token_embeddings(config.text_config.vocab_size + 2, pad_shape)
|
model.resize_token_embeddings(config.text_config.vocab_size + 2, pad_shape)
|
||||||
model.language_model.model.embed_tokens.weight.data[32000:] = torch.stack(
|
model.language_model.model.embed_tokens.weight.data[vocab_size:] = torch.stack(
|
||||||
tuple((dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[32000:].shape[0]))),
|
tuple(
|
||||||
|
(dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[vocab_size:].shape[0]))
|
||||||
|
),
|
||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
model.language_model.lm_head.weight.data[32000:] = torch.stack(
|
model.language_model.lm_head.weight.data[vocab_size:] = torch.stack(
|
||||||
tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[32000:].shape[0]))),
|
tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[vocab_size:].shape[0]))),
|
||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user