[PixtralLarge] Update Pixtral conversion script to support large format! (#34801)

* update conversion script

* update for bias again

* remove pdv

* use my dir

* Update how we initialize the tokenizer

* Convert in bfloat16

* Undo that one again

* fix config dump

* .to() was broken for BatchMixFeature

* quick debug breakpoint

* put the breakpoint in the right place

* Add a config flag for the multimodal projector bias

* Add a config flag for the multimodal projector bias

* Conversion script can load chat templates

* Indent config for comparison

* Stop clobbering the config

* Re-enable the config clobber

* Get rid of the config manual save - it has no effect!

* Handle adapter bias correctly

* Default vision transformer activation to silu

* Remove legacy processing path

* One commit with all the debug breakpoints before I delete them all, in case I need to revert

* Update conversion

* Remove vLLM debugging instrumentation

* Drop xformers

* Remove debug enumerates

* make fixup

* make fixup

* Break copied from in pixtral

* Propagate multimodal_projector_bias change

* Propagate multimodal_projector_bias change

* Remove debug device .to()

* Restore attention weights output

* Fix Pixtral test

* Drop image_seq_length

* Drop image_seq_length

* Put the legacy processing code back

* Add the bias option to the llava_next_video config

* Add the bias option to the llava_next_video config

* Make certain args required in converter

* Make certain args required in converter

* typo

* make fixup

* Reverting some dtype changes since it seems to work without them

---------

Co-authored-by: arthur@huggingface.co <arthur@ip-26-0-166-244.ec2.internal>
Co-authored-by: Matt <rocketknight1@gmail.com>
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
Arthur
2025-01-08 17:39:47 +01:00
committed by GitHub
parent 4c2c12b3de
commit 3f483beab9
16 changed files with 199 additions and 114 deletions

View File

@@ -50,6 +50,8 @@ class LlavaConfig(PretrainedConfig):
The index of the layer to select the vision feature.
image_seq_length (`int`, *optional*, defaults to 576):
Sequence length of one image embedding.
multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
Whether to use bias in the multimodal projector.
Example:
@@ -85,6 +87,7 @@ class LlavaConfig(PretrainedConfig):
vision_feature_select_strategy="default",
vision_feature_layer=-2,
image_seq_length=576,
multimodal_projector_bias=True,
**kwargs,
):
self.ignore_index = ignore_index
@@ -127,6 +130,7 @@ class LlavaConfig(PretrainedConfig):
text_config = CONFIG_MAPPING["llama"]()
self.text_config = text_config
self.multimodal_projector_bias = multimodal_projector_bias
super().__init__(**kwargs)

View File

@@ -86,10 +86,13 @@ class LlavaCausalLMOutputWithPast(ModelOutput):
class LlavaMultiModalProjector(nn.Module):
def __init__(self, config: LlavaConfig):
super().__init__()
self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
self.linear_1 = nn.Linear(
config.vision_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
)
self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
self.linear_2 = nn.Linear(
config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
)
def forward(self, image_features):
hidden_states = self.linear_1(image_features)

View File

@@ -55,6 +55,8 @@ class LlavaNextConfig(PretrainedConfig):
Whether the model's input and output word embeddings should be tied.
image_seq_length (`int`, *optional*, defaults to 576):
Sequence length of one image embedding.
multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
Whether to use bias in the multimodal projector.
Example:
@@ -92,12 +94,14 @@ class LlavaNextConfig(PretrainedConfig):
image_grid_pinpoints=None,
tie_word_embeddings=False,
image_seq_length=576,
multimodal_projector_bias=True,
**kwargs,
):
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.projector_hidden_act = projector_hidden_act
self.image_seq_length = image_seq_length
self.multimodal_projector_bias = multimodal_projector_bias
if vision_feature_select_strategy not in ["default", "full"]:
raise ValueError(

View File

@@ -194,10 +194,13 @@ class LlavaNextCausalLMOutputWithPast(ModelOutput):
class LlavaNextMultiModalProjector(nn.Module):
def __init__(self, config: LlavaNextConfig):
super().__init__()
self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
self.linear_1 = nn.Linear(
config.vision_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
)
self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
self.linear_2 = nn.Linear(
config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
)
def forward(self, image_features):
hidden_states = self.linear_1(image_features)

View File

@@ -44,6 +44,8 @@ class LlavaNextVideoConfig(PretrainedConfig):
The image token index to encode the image prompt.
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
The activation function used by the multimodal projector.
multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
Whether to use bias in the multimodal projector.
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
@@ -95,6 +97,7 @@ class LlavaNextVideoConfig(PretrainedConfig):
ignore_index=-100,
image_token_index=32001,
projector_hidden_act="gelu",
multimodal_projector_bias=True,
vision_feature_select_strategy="default",
vision_feature_layer=-2,
image_grid_pinpoints=None,
@@ -114,6 +117,7 @@ class LlavaNextVideoConfig(PretrainedConfig):
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.projector_hidden_act = projector_hidden_act
self.multimodal_projector_bias = multimodal_projector_bias
if vision_feature_select_strategy not in ["default", "full"]:
raise ValueError(

View File

@@ -179,10 +179,13 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel):
class LlavaNextVideoMultiModalProjector(nn.Module):
def __init__(self, config: LlavaNextVideoConfig):
super().__init__()
self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
self.linear_1 = nn.Linear(
config.vision_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
)
self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
self.linear_2 = nn.Linear(
config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
)
def forward(self, image_features):
hidden_states = self.linear_1(image_features)

View File

@@ -58,6 +58,8 @@ class LlavaNextVideoConfig(PretrainedConfig):
The image token index to encode the image prompt.
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
The activation function used by the multimodal projector.
multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
Whether to use bias in the multimodal projector.
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
@@ -109,6 +111,7 @@ class LlavaNextVideoConfig(PretrainedConfig):
ignore_index=-100,
image_token_index=32001,
projector_hidden_act="gelu",
multimodal_projector_bias=True,
vision_feature_select_strategy="default",
vision_feature_layer=-2,
image_grid_pinpoints=None,
@@ -128,6 +131,7 @@ class LlavaNextVideoConfig(PretrainedConfig):
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.projector_hidden_act = projector_hidden_act
self.multimodal_projector_bias = multimodal_projector_bias
if vision_feature_select_strategy not in ["default", "full"]:
raise ValueError(

View File

@@ -58,6 +58,8 @@ class LlavaOnevisionConfig(PretrainedConfig):
of the form `(height, width)`.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
Whether to use bias in the multimodal projector.
Example:
@@ -95,11 +97,13 @@ class LlavaOnevisionConfig(PretrainedConfig):
vision_aspect_ratio="anyres_max_9",
image_grid_pinpoints=None,
tie_word_embeddings=False,
multimodal_projector_bias=True,
**kwargs,
):
self.image_token_index = image_token_index
self.video_token_index = video_token_index
self.projector_hidden_act = projector_hidden_act
self.multimodal_projector_bias = multimodal_projector_bias
if vision_feature_select_strategy not in ["default", "full"]:
raise ValueError(

View File

@@ -201,10 +201,13 @@ class LlavaOnevisionCausalLMOutputWithPast(ModelOutput):
class LlavaOnevisionMultiModalProjector(nn.Module):
def __init__(self, config: LlavaOnevisionConfig):
super().__init__()
self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
self.linear_1 = nn.Linear(
config.vision_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
)
self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
self.linear_2 = nn.Linear(
config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
)
def forward(self, image_features):
hidden_states = self.linear_1(image_features)

View File

@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import regex as re
import torch
@@ -27,7 +29,6 @@ from transformers import (
PixtralImageProcessor,
PixtralProcessor,
PixtralVisionConfig,
PreTrainedTokenizerFast,
)
from transformers.convert_slow_tokenizer import bytes_to_unicode
@@ -156,29 +157,18 @@ class MistralConverter:
return tokenizer
def convert_mistral_tokenizer():
model_name = "mistralai/Pixtral-12B-2409"
def convert_mistral_tokenizer(model_file):
from transformers import LlamaTokenizer
tokenizer = MistralTokenizer.from_model(model_name)
vocab = tokenizer.instruct_tokenizer.tokenizer._tekken_token2id_nospecial
all_special = [
token.value if hasattr(token, "value") else token
for token in tokenizer.instruct_tokenizer.tokenizer._all_special_tokens
]
specials_tokens = {token: all_special.index(token) for token in all_special}
specials_tokens.update(vocab)
vocab = specials_tokens
tokenizer = PreTrainedTokenizerFast(
tokenizer_object=MistralConverter(vocab=vocab, additional_special_tokens=all_special).converted(),
bos_token="<s>",
unk_token="<unk>",
eos_token="</s>",
)
tokenizer.model_input_names = ["input_ids", "attention_mask"]
return tokenizer
mistral_tokenizer = MistralTokenizer.from_file(model_file)
vocab = mistral_tokenizer.instruct_tokenizer.tokenizer.vocab()
control_token_ids = mistral_tokenizer.instruct_tokenizer.tokenizer._control_tokens
all_special = [vocab[id] for id in control_token_ids]
hf_tokenizer = LlamaTokenizer(model_file)
# Do I need to exclude tokens that are already special?
hf_tokenizer.add_special_tokens({"additional_special_tokens": all_special})
hf_tokenizer.model_input_names = ["input_ids", "attention_mask"]
return hf_tokenizer
def permute_for_rope(value, n_heads, config):
@@ -187,7 +177,7 @@ def permute_for_rope(value, n_heads, config):
return value.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
def convert_dictionnary(original_state_dict, vision_config, text_config):
def convert_dictionary(original_state_dict, vision_config, text_config):
new_dict = {}
all_keys = "\n" + "\n".join(original_state_dict.keys())
@@ -208,7 +198,6 @@ def convert_dictionnary(original_state_dict, vision_config, text_config):
num_attention_heads = _config.num_attention_heads
if "k_proj" in new_key:
num_attention_heads = _config.num_key_value_heads
# convert the text model (basically mistral model)
if "q_proj" in new_key or "k_proj" in new_key:
value = permute_for_rope(value, num_attention_heads, _config)
@@ -217,7 +206,35 @@ def convert_dictionnary(original_state_dict, vision_config, text_config):
return new_dict
MISTRAL_CONFIG_MAPPING = {
"dim": "hidden_size",
"hidden_dim": "intermediate_size",
"n_kv_heads": "num_key_value_heads",
"n_heads": "num_attention_heads",
"n_layers": "num_hidden_layers",
}
def convert_mistral_model(input_dir, output_dir):
vision_config = {}
if os.path.isfile(f"{input_dir}/params.json"):
with open(f"{input_dir}/params.json") as f:
param_json = json.load(f)
vision_config = param_json.pop("vision_encoder")
for k, v in MISTRAL_CONFIG_MAPPING.items():
value = param_json.pop(k)
param_json[v] = value
if "hidden_act" not in vision_config:
vision_config["hidden_act"] = "silu"
text_config = MistralConfig(
**param_json,
hidden_act="silu",
sliding_window=None,
tie_word_embeddings=False,
is_composition=True,
rms_norm_eps=1e-5,
)
else:
text_config = MistralConfig(
attention_dropout=0.0,
bos_token_id=1,
@@ -238,8 +255,8 @@ def convert_mistral_model(input_dir, output_dir):
tie_word_embeddings=False,
vocab_size=131072,
)
vision_config = PixtralVisionConfig()
adapter_bias = vision_config.pop("adapter_bias", True)
vision_config = PixtralVisionConfig(**vision_config)
config = LlavaConfig(
vision_config,
text_config,
@@ -247,38 +264,55 @@ def convert_mistral_model(input_dir, output_dir):
image_token_index=10,
vision_feature_select_strategy="full",
image_seq_length=1,
multimodal_projector_bias=adapter_bias,
)
config.architectures = ["LlavaForConditionalGeneration"]
config.save_pretrained(output_dir)
full_original_state_dict = {}
safetensors_files = sorted([file for file in os.listdir(input_dir) if file.endswith(".safetensors")])
if len(safetensors_files) == 1:
full_original_state_dict = safe_load_file(f"{input_dir}/consolidated.safetensors")
else:
for file in safetensors_files:
loaded_dict = safe_load_file(f"{input_dir}/{file}")
full_original_state_dict.update(loaded_dict)
original_state_dict = safe_load_file(f"{input_dir}/consolidated.safetensors")
new_dict = convert_dictionnary(original_state_dict, vision_config, text_config)
new_dict = convert_dictionary(full_original_state_dict, vision_config, text_config)
with torch.device("meta"):
model = LlavaForConditionalGeneration(config)
model.load_state_dict(new_dict, strict=True, assign=True)
model.save_pretrained(output_dir)
tokenizer = convert_mistral_tokenizer()
image_processor = PixtralImageProcessor()
processor = PixtralProcessor(tokenizer=tokenizer, image_processor=image_processor, image_token="[IMG]")
processor.save_pretrained(output_dir)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_dir",
help="Location of LLaMA weights, which contains tokenizer.model and model folders",
required=True,
)
parser.add_argument(
"--output_dir",
help="Location to write HF model and tokenizer",
required=True,
)
parser.add_argument(
"--tokenizer_file", help="Location of the specific tokenizer model file to use.", required=True
)
parser.add_argument(
"--chat_template_file",
help="Optional file containing a raw chat template. Will be set as the processor's chat template.",
required=False,
)
args = parser.parse_args()
convert_mistral_model(args.input_dir, args.output_dir)
tokenizer = convert_mistral_tokenizer(args.tokenizer_file)
image_processor = PixtralImageProcessor()
processor = PixtralProcessor(tokenizer=tokenizer, image_processor=image_processor, image_token="[IMG]")
if args.chat_template_file:
processor.chat_template = open(args.chat_template_file).read()
processor.save_pretrained(args.output_dir)
if __name__ == "__main__":

View File

@@ -37,7 +37,7 @@ from ...image_utils import (
validate_kwargs,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_torch_device, is_torch_dtype, is_torch_tensor, is_vision_available, logging
from ...utils import TensorType, is_torch_device, is_torch_dtype, is_vision_available, logging
from ...utils.import_utils import requires_backends
@@ -63,10 +63,24 @@ class BatchMixFeature(BatchFeature):
Returns:
[`BatchFeature`]: The same instance after modification.
"""
def _recursive_to(obj, device, *args, **kwargs):
# Lists can be nested, so keep digging until we hit tensors
if isinstance(obj, list):
return [_recursive_to(o, device, *args, **kwargs) for o in obj]
# We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj):
# cast and send to device
return obj.to(*args, **kwargs)
elif isinstance(obj, torch.Tensor) and device is not None:
# only send to device, don't cast
return obj.to(device=device)
else:
return obj
requires_backends(self, ["torch"])
import torch # noqa
new_data = {}
device = kwargs.get("device")
# Check if the args are a device or a dtype
if device is None and len(args) > 0:
@@ -80,21 +94,8 @@ class BatchMixFeature(BatchFeature):
else:
# it's something else
raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
# We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
for k, v in self.items():
# check if v is a floating point
if isinstance(v, list):
new_data[k] = [
element.to(*args, **kwargs) for sample in v for element in sample if is_torch_tensor(element)
]
elif isinstance(v, torch.Tensor) and torch.is_floating_point(v):
# cast and send to device
new_data[k] = v.to(*args, **kwargs)
elif isinstance(v, torch.Tensor) and device is not None:
new_data[k] = v.to(device=device)
else:
new_data[k] = v
self.data = new_data
self.data = {k: _recursive_to(v, device, *args, **kwargs) for k, v in self.data.items()}
return self

View File

@@ -126,7 +126,6 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -491,18 +490,20 @@ class PixtralVisionModel(PixtralPreTrainedModel):
all tokens of all images of shape (N_toks, D)
"""
# pass images through initial convolution independently
patch_embeds_list = [self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in pixel_values]
if len(pixel_values) > 1:
raise ValueError("Batching/padding not supported yet!")
patch_embeds_list = [self.patch_conv(img.to(self.dtype)) for sample in pixel_values for img in sample]
# flatten to a single sequence
patch_embeds = torch.cat([p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
patch_embeds = torch.cat([p.flatten(1).T for p in patch_embeds_list], dim=0).unsqueeze(0)
patch_embeds = self.ln_pre(patch_embeds)
# positional embeddings
position_ids = position_ids_in_meshgrid(
patch_embeds_list, max_width=self.config.image_size // self.config.patch_size
).to(self.device)
position_embedding = self.patch_positional_embedding(patch_embeds, position_ids)
attention_mask = generate_block_attention_mask(
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds
)

View File

@@ -22,7 +22,7 @@ from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image, load_image
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import is_torch_device, is_torch_dtype, is_torch_tensor, logging, requires_backends
from ...utils import is_torch_device, is_torch_dtype, logging, requires_backends
logger = logging.get_logger(__name__)
@@ -66,10 +66,24 @@ class BatchMixFeature(BatchFeature):
Returns:
[`BatchFeature`]: The same instance after modification.
"""
def _recursive_to(obj, device, *args, **kwargs):
# Lists can be nested, so keep digging until we hit tensors
if isinstance(obj, list):
return [_recursive_to(o, device, *args, **kwargs) for o in obj]
# We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj):
# cast and send to device
return obj.to(*args, **kwargs)
elif isinstance(obj, torch.Tensor) and device is not None:
# only send to device, don't cast
return obj.to(device=device)
else:
return obj
requires_backends(self, ["torch"])
import torch # noqa
new_data = {}
device = kwargs.get("device")
# Check if the args are a device or a dtype
if device is None and len(args) > 0:
@@ -83,21 +97,8 @@ class BatchMixFeature(BatchFeature):
else:
# it's something else
raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
# We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
for k, v in self.items():
# check if v is a floating point
if isinstance(v, list):
new_data[k] = [
element.to(*args, **kwargs) for sample in v for element in sample if is_torch_tensor(element)
]
elif isinstance(v, torch.Tensor) and torch.is_floating_point(v):
# cast and send to device
new_data[k] = v.to(*args, **kwargs)
elif isinstance(v, torch.Tensor) and device is not None:
new_data[k] = v.to(device=device)
else:
new_data[k] = v
self.data = new_data
self.data = {k: _recursive_to(v, device, *args, **kwargs) for k, v in self.data.items()}
return self
@@ -204,12 +205,21 @@ class PixtralProcessor(ProcessorMixin):
if images is not None:
if is_image_or_image_url(images):
if isinstance(text, str) or isinstance(text, list) and len(text) == 1:
# If there's a single sample, the image must belong to it
images = [[images]]
elif isinstance(images, list) and is_image_or_image_url(images[0]):
if isinstance(text, list):
images = [[im] for im in images]
else:
raise ValueError(
"You have supplied multiple text samples, but `images` is not a nested list. When processing multiple samples, `images` should be a list of lists of images, one list per sample."
)
elif isinstance(images, list) and is_image_or_image_url(images[0]):
if isinstance(text, str) or isinstance(text, list) and len(text) == 1:
# If there's a single sample, all images must belong to it
images = [images]
else:
raise ValueError(
"You have supplied multiple text samples, but `images` is not a nested list. When processing multiple samples, `images` should be a list of lists of images, one list per sample."
)
elif isinstance(images, list) and isinstance(images[0], list) and is_image_or_image_url(images[0][0]):
pass
else:

View File

@@ -55,6 +55,8 @@ class VideoLlavaConfig(PretrainedConfig):
Sequence length of one image embedding.
video_seq_length (`int`, *optional*, defaults to 2056):
Sequence length of one video embedding.
multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
Whether to use bias in the multimodal projector.
Example:
@@ -92,6 +94,7 @@ class VideoLlavaConfig(PretrainedConfig):
vision_feature_layer=-2,
image_seq_length=256,
video_seq_length=2056,
multimodal_projector_bias=True,
**kwargs,
):
self.ignore_index = ignore_index
@@ -102,6 +105,7 @@ class VideoLlavaConfig(PretrainedConfig):
self.vision_feature_layer = vision_feature_layer
self.image_seq_length = image_seq_length
self.video_seq_length = video_seq_length
self.multimodal_projector_bias = multimodal_projector_bias
self.vision_config = vision_config

View File

@@ -88,10 +88,13 @@ class VideoLlavaCausalLMOutputWithPast(ModelOutput):
class VideoLlavaMultiModalProjector(nn.Module):
def __init__(self, config: VideoLlavaConfig):
super().__init__()
self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
self.linear_1 = nn.Linear(
config.vision_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
)
self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
self.linear_2 = nn.Linear(
config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
)
def forward(self, image_features):
hidden_states = self.linear_1(image_features)

View File

@@ -253,7 +253,7 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
"USER: [IMG]\nWhat's the content of the image? ASSISTANT:",
] * 5
processor.tokenizer.pad_token = "</s>"
image_inputs = [self.image_0] * 5
image_inputs = [[self.image_0]] * 5
# Make small for checking image token expansion
processor.image_processor.size = {"longest_edge": 30}