Add conversion for interleave llava (#31858)
* add conversion for interleave llava * remove debug lines * remove unused imports * Update src/transformers/models/llava/convert_llava_weights_to_hf.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * small changes + docs --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
ad35309a62
commit
97aa3e2905
@@ -40,8 +40,20 @@ The original code can be found [here](https://github.com/haotian-liu/LLaVA/tree/
|
||||
|
||||
- Note that the model has not been explicitly trained to process multiple images in the same prompt. Although this is technically possible, you may experience inaccurate results.
|
||||
|
||||
- For better results, we recommend users to prompt the model with the correct prompt format:
|
||||
- For better results, we recommend users to prompt the model with the correct prompt format. Below is a list of prompt formats accepted by each llava checkpoint:
|
||||
|
||||
[llava-interleave models](https://huggingface.co/collections/llava-hf/llava-interleave-668e19a97da0036aad4a2f19) require the following format:
|
||||
```bash
|
||||
"<|im_start|>user <image>\nWhat is shown in this image?<|im_end|><|im_start|>assistant"
|
||||
```
|
||||
|
||||
For multi-turn conversations:
|
||||
|
||||
```bash
|
||||
"<|im_start|>user <image>\n<prompt1><|im_end|><|im_start|>assistant <answer1><|im_end|><|im_start|>user <image>\n<prompt2><|im_end|><|im_start|>assistant "
|
||||
```
|
||||
|
||||
[llava-1.5 models](https://huggingface.co/collections/llava-hf/llava-15-65f762d5b6941db5c2ba07e0) require the following format:
|
||||
```bash
|
||||
"USER: <image>\n<prompt> ASSISTANT:"
|
||||
```
|
||||
|
||||
@@ -12,18 +12,21 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import glob
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
from safetensors import safe_open
|
||||
|
||||
from transformers import (
|
||||
AddedToken,
|
||||
AutoConfig,
|
||||
AutoImageProcessor,
|
||||
AutoTokenizer,
|
||||
CLIPImageProcessor,
|
||||
LlavaConfig,
|
||||
LlavaForConditionalGeneration,
|
||||
LlavaProcessor,
|
||||
SiglipVisionConfig,
|
||||
)
|
||||
|
||||
|
||||
@@ -48,6 +51,7 @@ Example for creating the old state dict file with Python:
|
||||
|
||||
KEYS_TO_MODIFY_MAPPING = {
|
||||
"model.vision_tower.": "",
|
||||
".vision_resampler": "", # all lmms-lab models do avg pooling, so no vision_resampler
|
||||
"model.mm_projector": "multi_modal_projector",
|
||||
"model": "model.model",
|
||||
"vision_model.model": "vision_model",
|
||||
@@ -58,6 +62,26 @@ KEYS_TO_MODIFY_MAPPING = {
|
||||
}
|
||||
|
||||
|
||||
def load_original_state_dict(model_id):
    """Download a checkpoint's safetensors shards and merge them into one state dict.

    Args:
        model_id: Hub repo id forwarded to `snapshot_download`.

    Returns:
        A dict mapping each tensor name found in the downloaded `*.safetensors`
        files to its tensor, loaded on CPU.
    """
    directory_path = snapshot_download(repo_id=model_id, allow_patterns=["*.safetensors"])

    original_state_dict = {}
    for path in glob.glob(f"{directory_path}/*.safetensors"):
        with safe_open(path, framework="pt", device="cpu") as f:
            for key in f.keys():
                original_state_dict[key] = f.get_tensor(key)

    # tied weights so lm_head is not saved. Let's clone to load state dict
    if "lm_head.weight" not in original_state_dict:
        original_state_dict["lm_head.weight"] = original_state_dict["model.embed_tokens.weight"].clone()

    # not used in the original implementation because "merge_type=flat";
    # `pop` with a default so checkpoints without this key do not raise KeyError
    original_state_dict.pop("model.image_newline", None)
    return original_state_dict
|
||||
|
||||
|
||||
# used only for llava-interleave
|
||||
# for ex: Qwen/Qwen1.5-0.5B-Chat google/siglip-so400m-patch14-384 lmms-lab/llava-next-interleave-qwen-0.5b
|
||||
def convert_state_dict_to_hf(state_dict):
|
||||
new_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
@@ -77,24 +101,48 @@ def convert_llava_llama_to_hf(text_model_id, vision_model_id, output_hub_path, o
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(text_model_id)
|
||||
tokenizer.add_tokens(AddedToken("<image>", special=True, normalized=False), special_tokens=True)
|
||||
tokenizer.add_special_tokens({"pad_token": "<pad>"})
|
||||
|
||||
image_processor = CLIPImageProcessor.from_pretrained(vision_model_id)
|
||||
if "Qwen" not in text_model_id: # qwen already has a pad token
|
||||
tokenizer.add_special_tokens({"pad_token": "<pad>"})
|
||||
|
||||
image_processor = AutoImageProcessor.from_pretrained(vision_model_id)
|
||||
processor = LlavaProcessor(tokenizer=tokenizer, image_processor=image_processor)
|
||||
|
||||
config = LlavaConfig(text_config=text_config)
|
||||
config.pad_token_id = 32001
|
||||
if "Qwen" in text_model_id:
|
||||
vision_config = SiglipVisionConfig(
|
||||
hidden_size=1152,
|
||||
image_size=384,
|
||||
intermediate_size=4304,
|
||||
num_attention_heads=16,
|
||||
num_hidden_layers=26,
|
||||
patch_size=14,
|
||||
vision_use_head=False,
|
||||
).to_dict()
|
||||
else:
|
||||
vision_config = None
|
||||
|
||||
config = LlavaConfig(
|
||||
text_config=text_config,
|
||||
vision_config=vision_config,
|
||||
)
|
||||
|
||||
# lmms-lab interleave models do not use any selection strategy except for last hidden state
|
||||
if "Qwen" in text_model_id:
|
||||
config.image_token_index = 151646
|
||||
config.vision_feature_select_strategy = "full"
|
||||
config.vision_feature_layer = -1
|
||||
else:
|
||||
config.pad_token_id = 32001
|
||||
config.image_token_index = 32000
|
||||
|
||||
with torch.device("meta"):
|
||||
model = LlavaForConditionalGeneration(config)
|
||||
|
||||
# Pad to 64 for performance reasons
|
||||
pad_shape = 64
|
||||
if "Qwen" in text_model_id:
|
||||
state_dict = load_original_state_dict(old_state_dict_id)
|
||||
else:
|
||||
state_dict_path = hf_hub_download(old_state_dict_id, "model_state_dict.bin")
|
||||
state_dict = torch.load(state_dict_path, map_location="cpu")
|
||||
|
||||
state_dict_path = hf_hub_download(old_state_dict_id, "model_state_dict.bin")
|
||||
|
||||
state_dict = torch.load(state_dict_path, map_location="cpu")
|
||||
state_dict = convert_state_dict_to_hf(state_dict)
|
||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
|
||||
@@ -104,14 +152,18 @@ def convert_llava_llama_to_hf(text_model_id, vision_model_id, output_hub_path, o
|
||||
sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n
|
||||
dist = torch.distributions.multivariate_normal.MultivariateNormal(mu, covariance_matrix=1e-5 * sigma)
|
||||
|
||||
# We add an image token so we resize the model
|
||||
# We add an image token so we resize the model and pad to 64 for performance reasons
|
||||
pad_shape = 64
|
||||
vocab_size = config.text_config.vocab_size
|
||||
model.resize_token_embeddings(config.text_config.vocab_size + 2, pad_shape)
|
||||
model.language_model.model.embed_tokens.weight.data[32000:] = torch.stack(
|
||||
tuple((dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[32000:].shape[0]))),
|
||||
model.language_model.model.embed_tokens.weight.data[vocab_size:] = torch.stack(
|
||||
tuple(
|
||||
(dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[vocab_size:].shape[0]))
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
model.language_model.lm_head.weight.data[32000:] = torch.stack(
|
||||
tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[32000:].shape[0]))),
|
||||
model.language_model.lm_head.weight.data[vocab_size:] = torch.stack(
|
||||
tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[vocab_size:].shape[0]))),
|
||||
dim=0,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user