From 97aa3e2905aac867952d89a7d13851d8adb7861c Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Wed, 10 Jul 2024 12:12:21 +0500 Subject: [PATCH] Add conversion for interleave llava (#31858) * add conversion for interleave llava * remove debug lines * remove unused imports * Update src/transformers/models/llava/convert_llava_weights_to_hf.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * small changes + docs --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- docs/source/en/model_doc/llava.md | 14 ++- .../llava/convert_llava_weights_to_hf.py | 86 +++++++++++++++---- 2 files changed, 82 insertions(+), 18 deletions(-) diff --git a/docs/source/en/model_doc/llava.md b/docs/source/en/model_doc/llava.md index 0ca6382714..43eaa41d5d 100644 --- a/docs/source/en/model_doc/llava.md +++ b/docs/source/en/model_doc/llava.md @@ -40,8 +40,20 @@ The original code can be found [here](https://github.com/haotian-liu/LLaVA/tree/ - Note the model has not been explicitly trained to process multiple images in the same prompt, although this is technically possible, you may experience inaccurate results. -- For better results, we recommend users to prompt the model with the correct prompt format: +- For better results, we recommend users to prompt the model with the correct prompt format. Below is a list of prompt formats accepted by each llava checkpoint: +[llava-interleave models](https://huggingface.co/collections/llava-hf/llava-interleave-668e19a97da0036aad4a2f19) requires the following format: +```bash +"<|im_start|>user \nWhat is shown in this image?<|im_end|><|im_start|>assistant" +``` + +For multiple turns conversation: + +```bash +"<|im_start|>user \n<|im_end|><|im_start|>assistant <|im_end|><|im_start|>user \n<|im_end|><|im_start|>assistant " +``` + +[llava-1.5 models](https://huggingface.co/collections/llava-hf/llava-15-65f762d5b6941db5c2ba07e0) requires the following format: ```bash "USER: \n ASSISTANT:" ``` diff --git a/src/transformers/models/llava/convert_llava_weights_to_hf.py b/src/transformers/models/llava/convert_llava_weights_to_hf.py index bb40668f32..9841b7cb3d 100644 --- a/src/transformers/models/llava/convert_llava_weights_to_hf.py +++ b/src/transformers/models/llava/convert_llava_weights_to_hf.py @@ -12,18 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse +import glob import torch -from huggingface_hub import hf_hub_download +from huggingface_hub import hf_hub_download, snapshot_download +from safetensors import safe_open from transformers import ( AddedToken, AutoConfig, + AutoImageProcessor, AutoTokenizer, - CLIPImageProcessor, LlavaConfig, LlavaForConditionalGeneration, LlavaProcessor, + SiglipVisionConfig, ) @@ -48,6 +51,7 @@ Example for creating the old state dict file with Python: KEYS_TO_MODIFY_MAPPING = { "model.vision_tower.": "", + ".vision_resampler": "", # all lmms-lab models do avg pooling, so no vision_resampler "model.mm_projector": "multi_modal_projector", "model": "model.model", "vision_model.model": "vision_model", @@ -58,6 +62,26 @@ KEYS_TO_MODIFY_MAPPING = { } +def load_original_state_dict(model_id): + directory_path = snapshot_download(repo_id=model_id, allow_patterns=["*.safetensors"]) + + original_state_dict = {} + for path in glob.glob(f"{directory_path}/*"): + if path.endswith(".safetensors"): + with safe_open(path, framework="pt", device="cpu") as f: + for key in f.keys(): + original_state_dict[key] = f.get_tensor(key) + + # tied wieghts so lm.head is not saved. Let's clone to load state dict + if "lm_head.weight" not in original_state_dict: + original_state_dict["lm_head.weight"] = original_state_dict["model.embed_tokens.weight"].clone() + + del original_state_dict["model.image_newline"] # not used in the original implementation because "merge_type=flat" + return original_state_dict + + +# used only for llava-interlave +# for ex: Qwen/Qwen1.5-0.5B-Chat google/siglip-so400m-patch14-384 lmms-lab/llava-next-interleave-qwen-0.5b def convert_state_dict_to_hf(state_dict): new_state_dict = {} for key, value in state_dict.items(): @@ -77,24 +101,48 @@ def convert_llava_llama_to_hf(text_model_id, vision_model_id, output_hub_path, o tokenizer = AutoTokenizer.from_pretrained(text_model_id) tokenizer.add_tokens(AddedToken("", special=True, normalized=False), special_tokens=True) - tokenizer.add_special_tokens({"pad_token": ""}) - - image_processor = CLIPImageProcessor.from_pretrained(vision_model_id) + if "Qwen" not in text_model_id: # qwen already has a pad token + tokenizer.add_special_tokens({"pad_token": ""}) + image_processor = AutoImageProcessor.from_pretrained(vision_model_id) processor = LlavaProcessor(tokenizer=tokenizer, image_processor=image_processor) - config = LlavaConfig(text_config=text_config) - config.pad_token_id = 32001 + if "Qwen" in text_model_id: + vision_config = SiglipVisionConfig( + hidden_size=1152, + image_size=384, + intermediate_size=4304, + num_attention_heads=16, + num_hidden_layers=26, + patch_size=14, + vision_use_head=False, + ).to_dict() + else: + vision_config = None + + config = LlavaConfig( + text_config=text_config, + vision_config=vision_config, + ) + + # llms-lab interleeave models do not use any selection startegy except for last hidden state + if "Qwen" in text_model_id: + config.image_token_index = 151646 + config.vision_feature_select_strategy = "full" + config.vision_feature_layer = -1 + else: + config.pad_token_id = 32001 + config.image_token_index = 32000 with torch.device("meta"): model = LlavaForConditionalGeneration(config) - # Pad to 64 for performance reasons - pad_shape = 64 + if "Qwen" in text_model_id: + state_dict = load_original_state_dict(old_state_dict_id) + else: + state_dict_path = hf_hub_download(old_state_dict_id, "model_state_dict.bin") + state_dict = torch.load(state_dict_path, map_location="cpu") - state_dict_path = hf_hub_download(old_state_dict_id, "model_state_dict.bin") - - state_dict = torch.load(state_dict_path, map_location="cpu") state_dict = convert_state_dict_to_hf(state_dict) model.load_state_dict(state_dict, strict=True, assign=True) @@ -104,14 +152,18 @@ def convert_llava_llama_to_hf(text_model_id, vision_model_id, output_hub_path, o sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n dist = torch.distributions.multivariate_normal.MultivariateNormal(mu, covariance_matrix=1e-5 * sigma) - # We add an image token so we resize the model + # We add an image token so we resize the model and pad to 64 for performance reasons + pad_shape = 64 + vocab_size = config.text_config.vocab_size model.resize_token_embeddings(config.text_config.vocab_size + 2, pad_shape) - model.language_model.model.embed_tokens.weight.data[32000:] = torch.stack( - tuple((dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[32000:].shape[0]))), + model.language_model.model.embed_tokens.weight.data[vocab_size:] = torch.stack( + tuple( + (dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[vocab_size:].shape[0])) + ), dim=0, ) - model.language_model.lm_head.weight.data[32000:] = torch.stack( - tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[32000:].shape[0]))), + model.language_model.lm_head.weight.data[vocab_size:] = torch.stack( + tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[vocab_size:].shape[0]))), dim=0, )