[PixtralLarge] Update Pixtral conversion script to support large format! (#34801)
* update conversion script * update for bias again * remove pdv * use my dir * Update how we initialize the tokenizer * Convert in bfloat16 * Undo that one again * fix config dump * .to() was broken for BatchMixFeature * quick debug breakpoint * put the breakpoint in the right place * Add a config flag for the multimodal projector bias * Add a config flag for the multimodal projector bias * Conversion script can load chat templates * Indent config for comparison * Stop clobbering the config * Re-enable the config clobber * Get rid of the config manual save - it has no effect! * Handle adapter bias correctly * Default vision transformer activation to silu * Remove legacy processing path * One commit with all the debug breakpoints before I delete them all, in case I need to revert * Update conversion * Remove vLLM debugging instrumentation * Drop xformers * Remove debug enumerates * make fixup * make fixup * Break copied from in pixtral * Propagate multimodal_projector_bias change * Propagate multimodal_projector_bias change * Remove debug device .to() * Restore attention weights output * Fix Pixtral test * Drop image_seq_length * Drop image_seq_length * Put the legacy processing code back * Add the bias option to the llava_next_video config * Add the bias option to the llava_next_video config * Make certain args required in converter * Make certain args required in converter * typo * make fixup * Reverting some dtype changes since it seems to work without them --------- Co-authored-by: arthur@huggingface.co <arthur@ip-26-0-166-244.ec2.internal> Co-authored-by: Matt <rocketknight1@gmail.com> Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
@@ -50,6 +50,8 @@ class LlavaConfig(PretrainedConfig):
|
|||||||
The index of the layer to select the vision feature.
|
The index of the layer to select the vision feature.
|
||||||
image_seq_length (`int`, *optional*, defaults to 576):
|
image_seq_length (`int`, *optional*, defaults to 576):
|
||||||
Sequence length of one image embedding.
|
Sequence length of one image embedding.
|
||||||
|
multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to use bias in the multimodal projector.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@@ -85,6 +87,7 @@ class LlavaConfig(PretrainedConfig):
|
|||||||
vision_feature_select_strategy="default",
|
vision_feature_select_strategy="default",
|
||||||
vision_feature_layer=-2,
|
vision_feature_layer=-2,
|
||||||
image_seq_length=576,
|
image_seq_length=576,
|
||||||
|
multimodal_projector_bias=True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.ignore_index = ignore_index
|
self.ignore_index = ignore_index
|
||||||
@@ -127,6 +130,7 @@ class LlavaConfig(PretrainedConfig):
|
|||||||
text_config = CONFIG_MAPPING["llama"]()
|
text_config = CONFIG_MAPPING["llama"]()
|
||||||
|
|
||||||
self.text_config = text_config
|
self.text_config = text_config
|
||||||
|
self.multimodal_projector_bias = multimodal_projector_bias
|
||||||
|
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
|||||||
@@ -86,10 +86,13 @@ class LlavaCausalLMOutputWithPast(ModelOutput):
|
|||||||
class LlavaMultiModalProjector(nn.Module):
|
class LlavaMultiModalProjector(nn.Module):
|
||||||
def __init__(self, config: LlavaConfig):
|
def __init__(self, config: LlavaConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.linear_1 = nn.Linear(
|
||||||
self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
|
config.vision_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
|
||||||
|
)
|
||||||
self.act = ACT2FN[config.projector_hidden_act]
|
self.act = ACT2FN[config.projector_hidden_act]
|
||||||
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
|
self.linear_2 = nn.Linear(
|
||||||
|
config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, image_features):
|
def forward(self, image_features):
|
||||||
hidden_states = self.linear_1(image_features)
|
hidden_states = self.linear_1(image_features)
|
||||||
|
|||||||
@@ -55,6 +55,8 @@ class LlavaNextConfig(PretrainedConfig):
|
|||||||
Whether the model's input and output word embeddings should be tied.
|
Whether the model's input and output word embeddings should be tied.
|
||||||
image_seq_length (`int`, *optional*, defaults to 576):
|
image_seq_length (`int`, *optional*, defaults to 576):
|
||||||
Sequence length of one image embedding.
|
Sequence length of one image embedding.
|
||||||
|
multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to use bias in the multimodal projector.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@@ -92,12 +94,14 @@ class LlavaNextConfig(PretrainedConfig):
|
|||||||
image_grid_pinpoints=None,
|
image_grid_pinpoints=None,
|
||||||
tie_word_embeddings=False,
|
tie_word_embeddings=False,
|
||||||
image_seq_length=576,
|
image_seq_length=576,
|
||||||
|
multimodal_projector_bias=True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.ignore_index = ignore_index
|
self.ignore_index = ignore_index
|
||||||
self.image_token_index = image_token_index
|
self.image_token_index = image_token_index
|
||||||
self.projector_hidden_act = projector_hidden_act
|
self.projector_hidden_act = projector_hidden_act
|
||||||
self.image_seq_length = image_seq_length
|
self.image_seq_length = image_seq_length
|
||||||
|
self.multimodal_projector_bias = multimodal_projector_bias
|
||||||
|
|
||||||
if vision_feature_select_strategy not in ["default", "full"]:
|
if vision_feature_select_strategy not in ["default", "full"]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -194,10 +194,13 @@ class LlavaNextCausalLMOutputWithPast(ModelOutput):
|
|||||||
class LlavaNextMultiModalProjector(nn.Module):
|
class LlavaNextMultiModalProjector(nn.Module):
|
||||||
def __init__(self, config: LlavaNextConfig):
|
def __init__(self, config: LlavaNextConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.linear_1 = nn.Linear(
|
||||||
self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
|
config.vision_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
|
||||||
|
)
|
||||||
self.act = ACT2FN[config.projector_hidden_act]
|
self.act = ACT2FN[config.projector_hidden_act]
|
||||||
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
|
self.linear_2 = nn.Linear(
|
||||||
|
config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, image_features):
|
def forward(self, image_features):
|
||||||
hidden_states = self.linear_1(image_features)
|
hidden_states = self.linear_1(image_features)
|
||||||
|
|||||||
@@ -44,6 +44,8 @@ class LlavaNextVideoConfig(PretrainedConfig):
|
|||||||
The image token index to encode the image prompt.
|
The image token index to encode the image prompt.
|
||||||
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
|
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
|
||||||
The activation function used by the multimodal projector.
|
The activation function used by the multimodal projector.
|
||||||
|
multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to use bias in the multimodal projector.
|
||||||
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
|
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
|
||||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||||
Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
|
Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
|
||||||
@@ -95,6 +97,7 @@ class LlavaNextVideoConfig(PretrainedConfig):
|
|||||||
ignore_index=-100,
|
ignore_index=-100,
|
||||||
image_token_index=32001,
|
image_token_index=32001,
|
||||||
projector_hidden_act="gelu",
|
projector_hidden_act="gelu",
|
||||||
|
multimodal_projector_bias=True,
|
||||||
vision_feature_select_strategy="default",
|
vision_feature_select_strategy="default",
|
||||||
vision_feature_layer=-2,
|
vision_feature_layer=-2,
|
||||||
image_grid_pinpoints=None,
|
image_grid_pinpoints=None,
|
||||||
@@ -114,6 +117,7 @@ class LlavaNextVideoConfig(PretrainedConfig):
|
|||||||
self.ignore_index = ignore_index
|
self.ignore_index = ignore_index
|
||||||
self.image_token_index = image_token_index
|
self.image_token_index = image_token_index
|
||||||
self.projector_hidden_act = projector_hidden_act
|
self.projector_hidden_act = projector_hidden_act
|
||||||
|
self.multimodal_projector_bias = multimodal_projector_bias
|
||||||
|
|
||||||
if vision_feature_select_strategy not in ["default", "full"]:
|
if vision_feature_select_strategy not in ["default", "full"]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -179,10 +179,13 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel):
|
|||||||
class LlavaNextVideoMultiModalProjector(nn.Module):
|
class LlavaNextVideoMultiModalProjector(nn.Module):
|
||||||
def __init__(self, config: LlavaNextVideoConfig):
|
def __init__(self, config: LlavaNextVideoConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.linear_1 = nn.Linear(
|
||||||
self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
|
config.vision_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
|
||||||
|
)
|
||||||
self.act = ACT2FN[config.projector_hidden_act]
|
self.act = ACT2FN[config.projector_hidden_act]
|
||||||
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
|
self.linear_2 = nn.Linear(
|
||||||
|
config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, image_features):
|
def forward(self, image_features):
|
||||||
hidden_states = self.linear_1(image_features)
|
hidden_states = self.linear_1(image_features)
|
||||||
|
|||||||
@@ -58,6 +58,8 @@ class LlavaNextVideoConfig(PretrainedConfig):
|
|||||||
The image token index to encode the image prompt.
|
The image token index to encode the image prompt.
|
||||||
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
|
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
|
||||||
The activation function used by the multimodal projector.
|
The activation function used by the multimodal projector.
|
||||||
|
multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to use bias in the multimodal projector.
|
||||||
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
|
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
|
||||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||||
Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
|
Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
|
||||||
@@ -109,6 +111,7 @@ class LlavaNextVideoConfig(PretrainedConfig):
|
|||||||
ignore_index=-100,
|
ignore_index=-100,
|
||||||
image_token_index=32001,
|
image_token_index=32001,
|
||||||
projector_hidden_act="gelu",
|
projector_hidden_act="gelu",
|
||||||
|
multimodal_projector_bias=True,
|
||||||
vision_feature_select_strategy="default",
|
vision_feature_select_strategy="default",
|
||||||
vision_feature_layer=-2,
|
vision_feature_layer=-2,
|
||||||
image_grid_pinpoints=None,
|
image_grid_pinpoints=None,
|
||||||
@@ -128,6 +131,7 @@ class LlavaNextVideoConfig(PretrainedConfig):
|
|||||||
self.ignore_index = ignore_index
|
self.ignore_index = ignore_index
|
||||||
self.image_token_index = image_token_index
|
self.image_token_index = image_token_index
|
||||||
self.projector_hidden_act = projector_hidden_act
|
self.projector_hidden_act = projector_hidden_act
|
||||||
|
self.multimodal_projector_bias = multimodal_projector_bias
|
||||||
|
|
||||||
if vision_feature_select_strategy not in ["default", "full"]:
|
if vision_feature_select_strategy not in ["default", "full"]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -58,6 +58,8 @@ class LlavaOnevisionConfig(PretrainedConfig):
|
|||||||
of the form `(height, width)`.
|
of the form `(height, width)`.
|
||||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||||
Whether the model's input and output word embeddings should be tied.
|
Whether the model's input and output word embeddings should be tied.
|
||||||
|
multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to use bias in the multimodal projector.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@@ -95,11 +97,13 @@ class LlavaOnevisionConfig(PretrainedConfig):
|
|||||||
vision_aspect_ratio="anyres_max_9",
|
vision_aspect_ratio="anyres_max_9",
|
||||||
image_grid_pinpoints=None,
|
image_grid_pinpoints=None,
|
||||||
tie_word_embeddings=False,
|
tie_word_embeddings=False,
|
||||||
|
multimodal_projector_bias=True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.image_token_index = image_token_index
|
self.image_token_index = image_token_index
|
||||||
self.video_token_index = video_token_index
|
self.video_token_index = video_token_index
|
||||||
self.projector_hidden_act = projector_hidden_act
|
self.projector_hidden_act = projector_hidden_act
|
||||||
|
self.multimodal_projector_bias = multimodal_projector_bias
|
||||||
|
|
||||||
if vision_feature_select_strategy not in ["default", "full"]:
|
if vision_feature_select_strategy not in ["default", "full"]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -201,10 +201,13 @@ class LlavaOnevisionCausalLMOutputWithPast(ModelOutput):
|
|||||||
class LlavaOnevisionMultiModalProjector(nn.Module):
|
class LlavaOnevisionMultiModalProjector(nn.Module):
|
||||||
def __init__(self, config: LlavaOnevisionConfig):
|
def __init__(self, config: LlavaOnevisionConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.linear_1 = nn.Linear(
|
||||||
self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
|
config.vision_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
|
||||||
|
)
|
||||||
self.act = ACT2FN[config.projector_hidden_act]
|
self.act = ACT2FN[config.projector_hidden_act]
|
||||||
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
|
self.linear_2 = nn.Linear(
|
||||||
|
config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, image_features):
|
def forward(self, image_features):
|
||||||
hidden_states = self.linear_1(image_features)
|
hidden_states = self.linear_1(image_features)
|
||||||
|
|||||||
@@ -12,6 +12,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
import regex as re
|
import regex as re
|
||||||
import torch
|
import torch
|
||||||
@@ -27,7 +29,6 @@ from transformers import (
|
|||||||
PixtralImageProcessor,
|
PixtralImageProcessor,
|
||||||
PixtralProcessor,
|
PixtralProcessor,
|
||||||
PixtralVisionConfig,
|
PixtralVisionConfig,
|
||||||
PreTrainedTokenizerFast,
|
|
||||||
)
|
)
|
||||||
from transformers.convert_slow_tokenizer import bytes_to_unicode
|
from transformers.convert_slow_tokenizer import bytes_to_unicode
|
||||||
|
|
||||||
@@ -156,29 +157,18 @@ class MistralConverter:
|
|||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
def convert_mistral_tokenizer():
|
def convert_mistral_tokenizer(model_file):
|
||||||
model_name = "mistralai/Pixtral-12B-2409"
|
from transformers import LlamaTokenizer
|
||||||
|
|
||||||
tokenizer = MistralTokenizer.from_model(model_name)
|
mistral_tokenizer = MistralTokenizer.from_file(model_file)
|
||||||
|
vocab = mistral_tokenizer.instruct_tokenizer.tokenizer.vocab()
|
||||||
vocab = tokenizer.instruct_tokenizer.tokenizer._tekken_token2id_nospecial
|
control_token_ids = mistral_tokenizer.instruct_tokenizer.tokenizer._control_tokens
|
||||||
all_special = [
|
all_special = [vocab[id] for id in control_token_ids]
|
||||||
token.value if hasattr(token, "value") else token
|
hf_tokenizer = LlamaTokenizer(model_file)
|
||||||
for token in tokenizer.instruct_tokenizer.tokenizer._all_special_tokens
|
# Do I need to exclude tokens that are already special?
|
||||||
]
|
hf_tokenizer.add_special_tokens({"additional_special_tokens": all_special})
|
||||||
specials_tokens = {token: all_special.index(token) for token in all_special}
|
hf_tokenizer.model_input_names = ["input_ids", "attention_mask"]
|
||||||
specials_tokens.update(vocab)
|
return hf_tokenizer
|
||||||
vocab = specials_tokens
|
|
||||||
|
|
||||||
tokenizer = PreTrainedTokenizerFast(
|
|
||||||
tokenizer_object=MistralConverter(vocab=vocab, additional_special_tokens=all_special).converted(),
|
|
||||||
bos_token="<s>",
|
|
||||||
unk_token="<unk>",
|
|
||||||
eos_token="</s>",
|
|
||||||
)
|
|
||||||
tokenizer.model_input_names = ["input_ids", "attention_mask"]
|
|
||||||
|
|
||||||
return tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
def permute_for_rope(value, n_heads, config):
|
def permute_for_rope(value, n_heads, config):
|
||||||
@@ -187,7 +177,7 @@ def permute_for_rope(value, n_heads, config):
|
|||||||
return value.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
|
return value.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
|
||||||
|
|
||||||
|
|
||||||
def convert_dictionnary(original_state_dict, vision_config, text_config):
|
def convert_dictionary(original_state_dict, vision_config, text_config):
|
||||||
new_dict = {}
|
new_dict = {}
|
||||||
|
|
||||||
all_keys = "\n" + "\n".join(original_state_dict.keys())
|
all_keys = "\n" + "\n".join(original_state_dict.keys())
|
||||||
@@ -208,7 +198,6 @@ def convert_dictionnary(original_state_dict, vision_config, text_config):
|
|||||||
num_attention_heads = _config.num_attention_heads
|
num_attention_heads = _config.num_attention_heads
|
||||||
if "k_proj" in new_key:
|
if "k_proj" in new_key:
|
||||||
num_attention_heads = _config.num_key_value_heads
|
num_attention_heads = _config.num_key_value_heads
|
||||||
# convert the text model (basically mistral model)
|
|
||||||
|
|
||||||
if "q_proj" in new_key or "k_proj" in new_key:
|
if "q_proj" in new_key or "k_proj" in new_key:
|
||||||
value = permute_for_rope(value, num_attention_heads, _config)
|
value = permute_for_rope(value, num_attention_heads, _config)
|
||||||
@@ -217,29 +206,57 @@ def convert_dictionnary(original_state_dict, vision_config, text_config):
|
|||||||
return new_dict
|
return new_dict
|
||||||
|
|
||||||
|
|
||||||
def convert_mistral_model(input_dir, output_dir):
|
MISTRAL_CONFIG_MAPPING = {
|
||||||
text_config = MistralConfig(
|
"dim": "hidden_size",
|
||||||
attention_dropout=0.0,
|
"hidden_dim": "intermediate_size",
|
||||||
bos_token_id=1,
|
"n_kv_heads": "num_key_value_heads",
|
||||||
eos_token_id=2,
|
"n_heads": "num_attention_heads",
|
||||||
head_dim=128,
|
"n_layers": "num_hidden_layers",
|
||||||
hidden_act="silu",
|
}
|
||||||
hidden_size=5120,
|
|
||||||
initializer_range=0.02,
|
|
||||||
intermediate_size=14336,
|
|
||||||
max_position_embeddings=1024000,
|
|
||||||
model_type="mistral",
|
|
||||||
num_attention_heads=32,
|
|
||||||
num_hidden_layers=40,
|
|
||||||
num_key_value_heads=8,
|
|
||||||
rms_norm_eps=1e-05,
|
|
||||||
rope_theta=1000000000.0,
|
|
||||||
sliding_window=None,
|
|
||||||
tie_word_embeddings=False,
|
|
||||||
vocab_size=131072,
|
|
||||||
)
|
|
||||||
|
|
||||||
vision_config = PixtralVisionConfig()
|
|
||||||
|
def convert_mistral_model(input_dir, output_dir):
|
||||||
|
vision_config = {}
|
||||||
|
if os.path.isfile(f"{input_dir}/params.json"):
|
||||||
|
with open(f"{input_dir}/params.json") as f:
|
||||||
|
param_json = json.load(f)
|
||||||
|
vision_config = param_json.pop("vision_encoder")
|
||||||
|
for k, v in MISTRAL_CONFIG_MAPPING.items():
|
||||||
|
value = param_json.pop(k)
|
||||||
|
param_json[v] = value
|
||||||
|
if "hidden_act" not in vision_config:
|
||||||
|
vision_config["hidden_act"] = "silu"
|
||||||
|
text_config = MistralConfig(
|
||||||
|
**param_json,
|
||||||
|
hidden_act="silu",
|
||||||
|
sliding_window=None,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
is_composition=True,
|
||||||
|
rms_norm_eps=1e-5,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
text_config = MistralConfig(
|
||||||
|
attention_dropout=0.0,
|
||||||
|
bos_token_id=1,
|
||||||
|
eos_token_id=2,
|
||||||
|
head_dim=128,
|
||||||
|
hidden_act="silu",
|
||||||
|
hidden_size=5120,
|
||||||
|
initializer_range=0.02,
|
||||||
|
intermediate_size=14336,
|
||||||
|
max_position_embeddings=1024000,
|
||||||
|
model_type="mistral",
|
||||||
|
num_attention_heads=32,
|
||||||
|
num_hidden_layers=40,
|
||||||
|
num_key_value_heads=8,
|
||||||
|
rms_norm_eps=1e-05,
|
||||||
|
rope_theta=1000000000.0,
|
||||||
|
sliding_window=None,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
vocab_size=131072,
|
||||||
|
)
|
||||||
|
adapter_bias = vision_config.pop("adapter_bias", True)
|
||||||
|
vision_config = PixtralVisionConfig(**vision_config)
|
||||||
config = LlavaConfig(
|
config = LlavaConfig(
|
||||||
vision_config,
|
vision_config,
|
||||||
text_config,
|
text_config,
|
||||||
@@ -247,38 +264,55 @@ def convert_mistral_model(input_dir, output_dir):
|
|||||||
image_token_index=10,
|
image_token_index=10,
|
||||||
vision_feature_select_strategy="full",
|
vision_feature_select_strategy="full",
|
||||||
image_seq_length=1,
|
image_seq_length=1,
|
||||||
|
multimodal_projector_bias=adapter_bias,
|
||||||
)
|
)
|
||||||
config.architectures = ["LlavaForConditionalGeneration"]
|
config.architectures = ["LlavaForConditionalGeneration"]
|
||||||
config.save_pretrained(output_dir)
|
config.save_pretrained(output_dir)
|
||||||
|
full_original_state_dict = {}
|
||||||
|
safetensors_files = sorted([file for file in os.listdir(input_dir) if file.endswith(".safetensors")])
|
||||||
|
if len(safetensors_files) == 1:
|
||||||
|
full_original_state_dict = safe_load_file(f"{input_dir}/consolidated.safetensors")
|
||||||
|
else:
|
||||||
|
for file in safetensors_files:
|
||||||
|
loaded_dict = safe_load_file(f"{input_dir}/{file}")
|
||||||
|
full_original_state_dict.update(loaded_dict)
|
||||||
|
|
||||||
original_state_dict = safe_load_file(f"{input_dir}/consolidated.safetensors")
|
new_dict = convert_dictionary(full_original_state_dict, vision_config, text_config)
|
||||||
new_dict = convert_dictionnary(original_state_dict, vision_config, text_config)
|
|
||||||
|
|
||||||
with torch.device("meta"):
|
with torch.device("meta"):
|
||||||
model = LlavaForConditionalGeneration(config)
|
model = LlavaForConditionalGeneration(config)
|
||||||
model.load_state_dict(new_dict, strict=True, assign=True)
|
model.load_state_dict(new_dict, strict=True, assign=True)
|
||||||
|
|
||||||
model.save_pretrained(output_dir)
|
model.save_pretrained(output_dir)
|
||||||
|
|
||||||
tokenizer = convert_mistral_tokenizer()
|
|
||||||
image_processor = PixtralImageProcessor()
|
|
||||||
processor = PixtralProcessor(tokenizer=tokenizer, image_processor=image_processor, image_token="[IMG]")
|
|
||||||
processor.save_pretrained(output_dir)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--input_dir",
|
"--input_dir",
|
||||||
help="Location of LLaMA weights, which contains tokenizer.model and model folders",
|
help="Location of LLaMA weights, which contains tokenizer.model and model folders",
|
||||||
|
required=True,
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--output_dir",
|
"--output_dir",
|
||||||
help="Location to write HF model and tokenizer",
|
help="Location to write HF model and tokenizer",
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokenizer_file", help="Location of the specific tokenizer model file to use.", required=True
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--chat_template_file",
|
||||||
|
help="Optional file containing a raw chat template. Will be set as the processor's chat template.",
|
||||||
|
required=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
convert_mistral_model(args.input_dir, args.output_dir)
|
convert_mistral_model(args.input_dir, args.output_dir)
|
||||||
|
tokenizer = convert_mistral_tokenizer(args.tokenizer_file)
|
||||||
|
image_processor = PixtralImageProcessor()
|
||||||
|
processor = PixtralProcessor(tokenizer=tokenizer, image_processor=image_processor, image_token="[IMG]")
|
||||||
|
if args.chat_template_file:
|
||||||
|
processor.chat_template = open(args.chat_template_file).read()
|
||||||
|
processor.save_pretrained(args.output_dir)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ from ...image_utils import (
|
|||||||
validate_kwargs,
|
validate_kwargs,
|
||||||
validate_preprocess_arguments,
|
validate_preprocess_arguments,
|
||||||
)
|
)
|
||||||
from ...utils import TensorType, is_torch_device, is_torch_dtype, is_torch_tensor, is_vision_available, logging
|
from ...utils import TensorType, is_torch_device, is_torch_dtype, is_vision_available, logging
|
||||||
from ...utils.import_utils import requires_backends
|
from ...utils.import_utils import requires_backends
|
||||||
|
|
||||||
|
|
||||||
@@ -63,10 +63,24 @@ class BatchMixFeature(BatchFeature):
|
|||||||
Returns:
|
Returns:
|
||||||
[`BatchFeature`]: The same instance after modification.
|
[`BatchFeature`]: The same instance after modification.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def _recursive_to(obj, device, *args, **kwargs):
|
||||||
|
# Lists can be nested, so keep digging until we hit tensors
|
||||||
|
if isinstance(obj, list):
|
||||||
|
return [_recursive_to(o, device, *args, **kwargs) for o in obj]
|
||||||
|
# We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
|
||||||
|
elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj):
|
||||||
|
# cast and send to device
|
||||||
|
return obj.to(*args, **kwargs)
|
||||||
|
elif isinstance(obj, torch.Tensor) and device is not None:
|
||||||
|
# only send to device, don't cast
|
||||||
|
return obj.to(device=device)
|
||||||
|
else:
|
||||||
|
return obj
|
||||||
|
|
||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
import torch # noqa
|
import torch # noqa
|
||||||
|
|
||||||
new_data = {}
|
|
||||||
device = kwargs.get("device")
|
device = kwargs.get("device")
|
||||||
# Check if the args are a device or a dtype
|
# Check if the args are a device or a dtype
|
||||||
if device is None and len(args) > 0:
|
if device is None and len(args) > 0:
|
||||||
@@ -80,21 +94,8 @@ class BatchMixFeature(BatchFeature):
|
|||||||
else:
|
else:
|
||||||
# it's something else
|
# it's something else
|
||||||
raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
|
raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
|
||||||
# We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
|
|
||||||
for k, v in self.items():
|
self.data = {k: _recursive_to(v, device, *args, **kwargs) for k, v in self.data.items()}
|
||||||
# check if v is a floating point
|
|
||||||
if isinstance(v, list):
|
|
||||||
new_data[k] = [
|
|
||||||
element.to(*args, **kwargs) for sample in v for element in sample if is_torch_tensor(element)
|
|
||||||
]
|
|
||||||
elif isinstance(v, torch.Tensor) and torch.is_floating_point(v):
|
|
||||||
# cast and send to device
|
|
||||||
new_data[k] = v.to(*args, **kwargs)
|
|
||||||
elif isinstance(v, torch.Tensor) and device is not None:
|
|
||||||
new_data[k] = v.to(device=device)
|
|
||||||
else:
|
|
||||||
new_data[k] = v
|
|
||||||
self.data = new_data
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -126,7 +126,6 @@ def rotate_half(x):
|
|||||||
return torch.cat((-x2, x1), dim=-1)
|
return torch.cat((-x2, x1), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
|
||||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||||
|
|
||||||
@@ -491,18 +490,20 @@ class PixtralVisionModel(PixtralPreTrainedModel):
|
|||||||
all tokens of all images of shape (N_toks, D)
|
all tokens of all images of shape (N_toks, D)
|
||||||
"""
|
"""
|
||||||
# pass images through initial convolution independently
|
# pass images through initial convolution independently
|
||||||
patch_embeds_list = [self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in pixel_values]
|
if len(pixel_values) > 1:
|
||||||
|
raise ValueError("Batching/padding not supported yet!")
|
||||||
|
patch_embeds_list = [self.patch_conv(img.to(self.dtype)) for sample in pixel_values for img in sample]
|
||||||
|
|
||||||
# flatten to a single sequence
|
# flatten to a single sequence
|
||||||
patch_embeds = torch.cat([p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
|
patch_embeds = torch.cat([p.flatten(1).T for p in patch_embeds_list], dim=0).unsqueeze(0)
|
||||||
patch_embeds = self.ln_pre(patch_embeds)
|
patch_embeds = self.ln_pre(patch_embeds)
|
||||||
|
|
||||||
# positional embeddings
|
# positional embeddings
|
||||||
position_ids = position_ids_in_meshgrid(
|
position_ids = position_ids_in_meshgrid(
|
||||||
patch_embeds_list, max_width=self.config.image_size // self.config.patch_size
|
patch_embeds_list, max_width=self.config.image_size // self.config.patch_size
|
||||||
).to(self.device)
|
).to(self.device)
|
||||||
|
|
||||||
position_embedding = self.patch_positional_embedding(patch_embeds, position_ids)
|
position_embedding = self.patch_positional_embedding(patch_embeds, position_ids)
|
||||||
|
|
||||||
attention_mask = generate_block_attention_mask(
|
attention_mask = generate_block_attention_mask(
|
||||||
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds
|
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from ...feature_extraction_utils import BatchFeature
|
|||||||
from ...image_utils import ImageInput, is_valid_image, load_image
|
from ...image_utils import ImageInput, is_valid_image, load_image
|
||||||
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
|
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
|
||||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||||
from ...utils import is_torch_device, is_torch_dtype, is_torch_tensor, logging, requires_backends
|
from ...utils import is_torch_device, is_torch_dtype, logging, requires_backends
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -66,10 +66,24 @@ class BatchMixFeature(BatchFeature):
|
|||||||
Returns:
|
Returns:
|
||||||
[`BatchFeature`]: The same instance after modification.
|
[`BatchFeature`]: The same instance after modification.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def _recursive_to(obj, device, *args, **kwargs):
|
||||||
|
# Lists can be nested, so keep digging until we hit tensors
|
||||||
|
if isinstance(obj, list):
|
||||||
|
return [_recursive_to(o, device, *args, **kwargs) for o in obj]
|
||||||
|
# We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
|
||||||
|
elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj):
|
||||||
|
# cast and send to device
|
||||||
|
return obj.to(*args, **kwargs)
|
||||||
|
elif isinstance(obj, torch.Tensor) and device is not None:
|
||||||
|
# only send to device, don't cast
|
||||||
|
return obj.to(device=device)
|
||||||
|
else:
|
||||||
|
return obj
|
||||||
|
|
||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
import torch # noqa
|
import torch # noqa
|
||||||
|
|
||||||
new_data = {}
|
|
||||||
device = kwargs.get("device")
|
device = kwargs.get("device")
|
||||||
# Check if the args are a device or a dtype
|
# Check if the args are a device or a dtype
|
||||||
if device is None and len(args) > 0:
|
if device is None and len(args) > 0:
|
||||||
@@ -83,21 +97,8 @@ class BatchMixFeature(BatchFeature):
|
|||||||
else:
|
else:
|
||||||
# it's something else
|
# it's something else
|
||||||
raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
|
raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
|
||||||
# We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
|
|
||||||
for k, v in self.items():
|
self.data = {k: _recursive_to(v, device, *args, **kwargs) for k, v in self.data.items()}
|
||||||
# check if v is a floating point
|
|
||||||
if isinstance(v, list):
|
|
||||||
new_data[k] = [
|
|
||||||
element.to(*args, **kwargs) for sample in v for element in sample if is_torch_tensor(element)
|
|
||||||
]
|
|
||||||
elif isinstance(v, torch.Tensor) and torch.is_floating_point(v):
|
|
||||||
# cast and send to device
|
|
||||||
new_data[k] = v.to(*args, **kwargs)
|
|
||||||
elif isinstance(v, torch.Tensor) and device is not None:
|
|
||||||
new_data[k] = v.to(device=device)
|
|
||||||
else:
|
|
||||||
new_data[k] = v
|
|
||||||
self.data = new_data
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
@@ -204,12 +205,21 @@ class PixtralProcessor(ProcessorMixin):
|
|||||||
|
|
||||||
if images is not None:
|
if images is not None:
|
||||||
if is_image_or_image_url(images):
|
if is_image_or_image_url(images):
|
||||||
images = [[images]]
|
if isinstance(text, str) or isinstance(text, list) and len(text) == 1:
|
||||||
elif isinstance(images, list) and is_image_or_image_url(images[0]):
|
# If there's a single sample, the image must belong to it
|
||||||
if isinstance(text, list):
|
images = [[images]]
|
||||||
images = [[im] for im in images]
|
|
||||||
else:
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"You have supplied multiple text samples, but `images` is not a nested list. When processing multiple samples, `images` should be a list of lists of images, one list per sample."
|
||||||
|
)
|
||||||
|
elif isinstance(images, list) and is_image_or_image_url(images[0]):
|
||||||
|
if isinstance(text, str) or isinstance(text, list) and len(text) == 1:
|
||||||
|
# If there's a single sample, all images must belong to it
|
||||||
images = [images]
|
images = [images]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"You have supplied multiple text samples, but `images` is not a nested list. When processing multiple samples, `images` should be a list of lists of images, one list per sample."
|
||||||
|
)
|
||||||
elif isinstance(images, list) and isinstance(images[0], list) and is_image_or_image_url(images[0][0]):
|
elif isinstance(images, list) and isinstance(images[0], list) and is_image_or_image_url(images[0][0]):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -55,6 +55,8 @@ class VideoLlavaConfig(PretrainedConfig):
|
|||||||
Sequence length of one image embedding.
|
Sequence length of one image embedding.
|
||||||
video_seq_length (`int`, *optional*, defaults to 2056):
|
video_seq_length (`int`, *optional*, defaults to 2056):
|
||||||
Sequence length of one video embedding.
|
Sequence length of one video embedding.
|
||||||
|
multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to use bias in the multimodal projector.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@@ -92,6 +94,7 @@ class VideoLlavaConfig(PretrainedConfig):
|
|||||||
vision_feature_layer=-2,
|
vision_feature_layer=-2,
|
||||||
image_seq_length=256,
|
image_seq_length=256,
|
||||||
video_seq_length=2056,
|
video_seq_length=2056,
|
||||||
|
multimodal_projector_bias=True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.ignore_index = ignore_index
|
self.ignore_index = ignore_index
|
||||||
@@ -102,6 +105,7 @@ class VideoLlavaConfig(PretrainedConfig):
|
|||||||
self.vision_feature_layer = vision_feature_layer
|
self.vision_feature_layer = vision_feature_layer
|
||||||
self.image_seq_length = image_seq_length
|
self.image_seq_length = image_seq_length
|
||||||
self.video_seq_length = video_seq_length
|
self.video_seq_length = video_seq_length
|
||||||
|
self.multimodal_projector_bias = multimodal_projector_bias
|
||||||
|
|
||||||
self.vision_config = vision_config
|
self.vision_config = vision_config
|
||||||
|
|
||||||
|
|||||||
@@ -88,10 +88,13 @@ class VideoLlavaCausalLMOutputWithPast(ModelOutput):
|
|||||||
class VideoLlavaMultiModalProjector(nn.Module):
|
class VideoLlavaMultiModalProjector(nn.Module):
|
||||||
def __init__(self, config: VideoLlavaConfig):
|
def __init__(self, config: VideoLlavaConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.linear_1 = nn.Linear(
|
||||||
self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
|
config.vision_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
|
||||||
|
)
|
||||||
self.act = ACT2FN[config.projector_hidden_act]
|
self.act = ACT2FN[config.projector_hidden_act]
|
||||||
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
|
self.linear_2 = nn.Linear(
|
||||||
|
config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, image_features):
|
def forward(self, image_features):
|
||||||
hidden_states = self.linear_1(image_features)
|
hidden_states = self.linear_1(image_features)
|
||||||
|
|||||||
@@ -253,7 +253,7 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
"USER: [IMG]\nWhat's the content of the image? ASSISTANT:",
|
"USER: [IMG]\nWhat's the content of the image? ASSISTANT:",
|
||||||
] * 5
|
] * 5
|
||||||
processor.tokenizer.pad_token = "</s>"
|
processor.tokenizer.pad_token = "</s>"
|
||||||
image_inputs = [self.image_0] * 5
|
image_inputs = [[self.image_0]] * 5
|
||||||
|
|
||||||
# Make small for checking image token expansion
|
# Make small for checking image token expansion
|
||||||
processor.image_processor.size = {"longest_edge": 30}
|
processor.image_processor.size = {"longest_edge": 30}
|
||||||
|
|||||||
Reference in New Issue
Block a user