[PixtralLarge] Update Pixtral conversion script to support large format! (#34801)

* update conversion script

* update for bias again

* remove pdv

* use my dir

* Update how we initialize the tokenizer

* Convert in bfloat16

* Undo that one again

* fix config dump

* .to() was broken for BatchMixFeature

* quick debug breakpoint

* put the breakpoint in the right place

* Add a config flag for the multimodal projector bias

* Add a config flag for the multimodal projector bias

* Conversion script can load chat templates

* Indent config for comparison

* Stop clobbering the config

* Re-enable the config clobber

* Get rid of the manual config save - it has no effect!

* Handle adapter bias correctly

* Default vision transformer activation to silu

* Remove legacy processing path

* One commit with all the debug breakpoints before I delete them all, in case I need to revert

* Update conversion

* Remove vLLM debugging instrumentation

* Drop xformers

* Remove debug enumerates

* make fixup

* make fixup

* Break copied from in pixtral

* Propagate multimodal_projector_bias change

* Propagate multimodal_projector_bias change

* Remove debug device .to()

* Restore attention weights output

* Fix Pixtral test

* Drop image_seq_length

* Drop image_seq_length

* Put the legacy processing code back

* Add the bias option to the llava_next_video config

* Add the bias option to the llava_next_video config

* Make certain args required in converter

* Make certain args required in converter

* typo

* make fixup

* Reverting some dtype changes since it seems to work without them

---------

Co-authored-by: arthur@huggingface.co <arthur@ip-26-0-166-244.ec2.internal>
Co-authored-by: Matt <rocketknight1@gmail.com>
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
Arthur
2025-01-08 17:39:47 +01:00
committed by GitHub
parent 4c2c12b3de
commit 3f483beab9
16 changed files with 199 additions and 114 deletions

View File

@@ -50,6 +50,8 @@ class LlavaConfig(PretrainedConfig):
The index of the layer to select the vision feature. The index of the layer to select the vision feature.
image_seq_length (`int`, *optional*, defaults to 576): image_seq_length (`int`, *optional*, defaults to 576):
Sequence length of one image embedding. Sequence length of one image embedding.
multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
Whether to use bias in the multimodal projector.
Example: Example:
@@ -85,6 +87,7 @@ class LlavaConfig(PretrainedConfig):
vision_feature_select_strategy="default", vision_feature_select_strategy="default",
vision_feature_layer=-2, vision_feature_layer=-2,
image_seq_length=576, image_seq_length=576,
multimodal_projector_bias=True,
**kwargs, **kwargs,
): ):
self.ignore_index = ignore_index self.ignore_index = ignore_index
@@ -127,6 +130,7 @@ class LlavaConfig(PretrainedConfig):
text_config = CONFIG_MAPPING["llama"]() text_config = CONFIG_MAPPING["llama"]()
self.text_config = text_config self.text_config = text_config
self.multimodal_projector_bias = multimodal_projector_bias
super().__init__(**kwargs) super().__init__(**kwargs)

View File

@@ -86,10 +86,13 @@ class LlavaCausalLMOutputWithPast(ModelOutput):
class LlavaMultiModalProjector(nn.Module): class LlavaMultiModalProjector(nn.Module):
def __init__(self, config: LlavaConfig): def __init__(self, config: LlavaConfig):
super().__init__() super().__init__()
self.linear_1 = nn.Linear(
self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True) config.vision_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
)
self.act = ACT2FN[config.projector_hidden_act] self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True) self.linear_2 = nn.Linear(
config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
)
def forward(self, image_features): def forward(self, image_features):
hidden_states = self.linear_1(image_features) hidden_states = self.linear_1(image_features)

View File

@@ -55,6 +55,8 @@ class LlavaNextConfig(PretrainedConfig):
Whether the model's input and output word embeddings should be tied. Whether the model's input and output word embeddings should be tied.
image_seq_length (`int`, *optional*, defaults to 576): image_seq_length (`int`, *optional*, defaults to 576):
Sequence length of one image embedding. Sequence length of one image embedding.
multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
Whether to use bias in the multimodal projector.
Example: Example:
@@ -92,12 +94,14 @@ class LlavaNextConfig(PretrainedConfig):
image_grid_pinpoints=None, image_grid_pinpoints=None,
tie_word_embeddings=False, tie_word_embeddings=False,
image_seq_length=576, image_seq_length=576,
multimodal_projector_bias=True,
**kwargs, **kwargs,
): ):
self.ignore_index = ignore_index self.ignore_index = ignore_index
self.image_token_index = image_token_index self.image_token_index = image_token_index
self.projector_hidden_act = projector_hidden_act self.projector_hidden_act = projector_hidden_act
self.image_seq_length = image_seq_length self.image_seq_length = image_seq_length
self.multimodal_projector_bias = multimodal_projector_bias
if vision_feature_select_strategy not in ["default", "full"]: if vision_feature_select_strategy not in ["default", "full"]:
raise ValueError( raise ValueError(

View File

@@ -194,10 +194,13 @@ class LlavaNextCausalLMOutputWithPast(ModelOutput):
class LlavaNextMultiModalProjector(nn.Module): class LlavaNextMultiModalProjector(nn.Module):
def __init__(self, config: LlavaNextConfig): def __init__(self, config: LlavaNextConfig):
super().__init__() super().__init__()
self.linear_1 = nn.Linear(
self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True) config.vision_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
)
self.act = ACT2FN[config.projector_hidden_act] self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True) self.linear_2 = nn.Linear(
config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
)
def forward(self, image_features): def forward(self, image_features):
hidden_states = self.linear_1(image_features) hidden_states = self.linear_1(image_features)

View File

@@ -44,6 +44,8 @@ class LlavaNextVideoConfig(PretrainedConfig):
The image token index to encode the image prompt. The image token index to encode the image prompt.
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
The activation function used by the multimodal projector. The activation function used by the multimodal projector.
multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
Whether to use bias in the multimodal projector.
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
The feature selection strategy used to select the vision feature from the vision backbone. The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features. Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
@@ -95,6 +97,7 @@ class LlavaNextVideoConfig(PretrainedConfig):
ignore_index=-100, ignore_index=-100,
image_token_index=32001, image_token_index=32001,
projector_hidden_act="gelu", projector_hidden_act="gelu",
multimodal_projector_bias=True,
vision_feature_select_strategy="default", vision_feature_select_strategy="default",
vision_feature_layer=-2, vision_feature_layer=-2,
image_grid_pinpoints=None, image_grid_pinpoints=None,
@@ -114,6 +117,7 @@ class LlavaNextVideoConfig(PretrainedConfig):
self.ignore_index = ignore_index self.ignore_index = ignore_index
self.image_token_index = image_token_index self.image_token_index = image_token_index
self.projector_hidden_act = projector_hidden_act self.projector_hidden_act = projector_hidden_act
self.multimodal_projector_bias = multimodal_projector_bias
if vision_feature_select_strategy not in ["default", "full"]: if vision_feature_select_strategy not in ["default", "full"]:
raise ValueError( raise ValueError(

View File

@@ -179,10 +179,13 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel):
class LlavaNextVideoMultiModalProjector(nn.Module): class LlavaNextVideoMultiModalProjector(nn.Module):
def __init__(self, config: LlavaNextVideoConfig): def __init__(self, config: LlavaNextVideoConfig):
super().__init__() super().__init__()
self.linear_1 = nn.Linear(
self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True) config.vision_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
)
self.act = ACT2FN[config.projector_hidden_act] self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True) self.linear_2 = nn.Linear(
config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
)
def forward(self, image_features): def forward(self, image_features):
hidden_states = self.linear_1(image_features) hidden_states = self.linear_1(image_features)

View File

@@ -58,6 +58,8 @@ class LlavaNextVideoConfig(PretrainedConfig):
The image token index to encode the image prompt. The image token index to encode the image prompt.
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
The activation function used by the multimodal projector. The activation function used by the multimodal projector.
multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
Whether to use bias in the multimodal projector.
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
The feature selection strategy used to select the vision feature from the vision backbone. The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features. Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
@@ -109,6 +111,7 @@ class LlavaNextVideoConfig(PretrainedConfig):
ignore_index=-100, ignore_index=-100,
image_token_index=32001, image_token_index=32001,
projector_hidden_act="gelu", projector_hidden_act="gelu",
multimodal_projector_bias=True,
vision_feature_select_strategy="default", vision_feature_select_strategy="default",
vision_feature_layer=-2, vision_feature_layer=-2,
image_grid_pinpoints=None, image_grid_pinpoints=None,
@@ -128,6 +131,7 @@ class LlavaNextVideoConfig(PretrainedConfig):
self.ignore_index = ignore_index self.ignore_index = ignore_index
self.image_token_index = image_token_index self.image_token_index = image_token_index
self.projector_hidden_act = projector_hidden_act self.projector_hidden_act = projector_hidden_act
self.multimodal_projector_bias = multimodal_projector_bias
if vision_feature_select_strategy not in ["default", "full"]: if vision_feature_select_strategy not in ["default", "full"]:
raise ValueError( raise ValueError(

View File

@@ -58,6 +58,8 @@ class LlavaOnevisionConfig(PretrainedConfig):
of the form `(height, width)`. of the form `(height, width)`.
tie_word_embeddings (`bool`, *optional*, defaults to `False`): tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied. Whether the model's input and output word embeddings should be tied.
multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
Whether to use bias in the multimodal projector.
Example: Example:
@@ -95,11 +97,13 @@ class LlavaOnevisionConfig(PretrainedConfig):
vision_aspect_ratio="anyres_max_9", vision_aspect_ratio="anyres_max_9",
image_grid_pinpoints=None, image_grid_pinpoints=None,
tie_word_embeddings=False, tie_word_embeddings=False,
multimodal_projector_bias=True,
**kwargs, **kwargs,
): ):
self.image_token_index = image_token_index self.image_token_index = image_token_index
self.video_token_index = video_token_index self.video_token_index = video_token_index
self.projector_hidden_act = projector_hidden_act self.projector_hidden_act = projector_hidden_act
self.multimodal_projector_bias = multimodal_projector_bias
if vision_feature_select_strategy not in ["default", "full"]: if vision_feature_select_strategy not in ["default", "full"]:
raise ValueError( raise ValueError(

View File

@@ -201,10 +201,13 @@ class LlavaOnevisionCausalLMOutputWithPast(ModelOutput):
class LlavaOnevisionMultiModalProjector(nn.Module): class LlavaOnevisionMultiModalProjector(nn.Module):
def __init__(self, config: LlavaOnevisionConfig): def __init__(self, config: LlavaOnevisionConfig):
super().__init__() super().__init__()
self.linear_1 = nn.Linear(
self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True) config.vision_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
)
self.act = ACT2FN[config.projector_hidden_act] self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True) self.linear_2 = nn.Linear(
config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
)
def forward(self, image_features): def forward(self, image_features):
hidden_states = self.linear_1(image_features) hidden_states = self.linear_1(image_features)

View File

@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import json
import os
import regex as re import regex as re
import torch import torch
@@ -27,7 +29,6 @@ from transformers import (
PixtralImageProcessor, PixtralImageProcessor,
PixtralProcessor, PixtralProcessor,
PixtralVisionConfig, PixtralVisionConfig,
PreTrainedTokenizerFast,
) )
from transformers.convert_slow_tokenizer import bytes_to_unicode from transformers.convert_slow_tokenizer import bytes_to_unicode
@@ -156,29 +157,18 @@ class MistralConverter:
return tokenizer return tokenizer
def convert_mistral_tokenizer(): def convert_mistral_tokenizer(model_file):
model_name = "mistralai/Pixtral-12B-2409" from transformers import LlamaTokenizer
tokenizer = MistralTokenizer.from_model(model_name) mistral_tokenizer = MistralTokenizer.from_file(model_file)
vocab = mistral_tokenizer.instruct_tokenizer.tokenizer.vocab()
vocab = tokenizer.instruct_tokenizer.tokenizer._tekken_token2id_nospecial control_token_ids = mistral_tokenizer.instruct_tokenizer.tokenizer._control_tokens
all_special = [ all_special = [vocab[id] for id in control_token_ids]
token.value if hasattr(token, "value") else token hf_tokenizer = LlamaTokenizer(model_file)
for token in tokenizer.instruct_tokenizer.tokenizer._all_special_tokens # Do I need to exclude tokens that are already special?
] hf_tokenizer.add_special_tokens({"additional_special_tokens": all_special})
specials_tokens = {token: all_special.index(token) for token in all_special} hf_tokenizer.model_input_names = ["input_ids", "attention_mask"]
specials_tokens.update(vocab) return hf_tokenizer
vocab = specials_tokens
tokenizer = PreTrainedTokenizerFast(
tokenizer_object=MistralConverter(vocab=vocab, additional_special_tokens=all_special).converted(),
bos_token="<s>",
unk_token="<unk>",
eos_token="</s>",
)
tokenizer.model_input_names = ["input_ids", "attention_mask"]
return tokenizer
def permute_for_rope(value, n_heads, config): def permute_for_rope(value, n_heads, config):
@@ -187,7 +177,7 @@ def permute_for_rope(value, n_heads, config):
return value.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) return value.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
def convert_dictionnary(original_state_dict, vision_config, text_config): def convert_dictionary(original_state_dict, vision_config, text_config):
new_dict = {} new_dict = {}
all_keys = "\n" + "\n".join(original_state_dict.keys()) all_keys = "\n" + "\n".join(original_state_dict.keys())
@@ -208,7 +198,6 @@ def convert_dictionnary(original_state_dict, vision_config, text_config):
num_attention_heads = _config.num_attention_heads num_attention_heads = _config.num_attention_heads
if "k_proj" in new_key: if "k_proj" in new_key:
num_attention_heads = _config.num_key_value_heads num_attention_heads = _config.num_key_value_heads
# convert the text model (basically mistral model)
if "q_proj" in new_key or "k_proj" in new_key: if "q_proj" in new_key or "k_proj" in new_key:
value = permute_for_rope(value, num_attention_heads, _config) value = permute_for_rope(value, num_attention_heads, _config)
@@ -217,29 +206,57 @@ def convert_dictionnary(original_state_dict, vision_config, text_config):
return new_dict return new_dict
def convert_mistral_model(input_dir, output_dir): MISTRAL_CONFIG_MAPPING = {
text_config = MistralConfig( "dim": "hidden_size",
attention_dropout=0.0, "hidden_dim": "intermediate_size",
bos_token_id=1, "n_kv_heads": "num_key_value_heads",
eos_token_id=2, "n_heads": "num_attention_heads",
head_dim=128, "n_layers": "num_hidden_layers",
hidden_act="silu", }
hidden_size=5120,
initializer_range=0.02,
intermediate_size=14336,
max_position_embeddings=1024000,
model_type="mistral",
num_attention_heads=32,
num_hidden_layers=40,
num_key_value_heads=8,
rms_norm_eps=1e-05,
rope_theta=1000000000.0,
sliding_window=None,
tie_word_embeddings=False,
vocab_size=131072,
)
vision_config = PixtralVisionConfig()
def convert_mistral_model(input_dir, output_dir):
vision_config = {}
if os.path.isfile(f"{input_dir}/params.json"):
with open(f"{input_dir}/params.json") as f:
param_json = json.load(f)
vision_config = param_json.pop("vision_encoder")
for k, v in MISTRAL_CONFIG_MAPPING.items():
value = param_json.pop(k)
param_json[v] = value
if "hidden_act" not in vision_config:
vision_config["hidden_act"] = "silu"
text_config = MistralConfig(
**param_json,
hidden_act="silu",
sliding_window=None,
tie_word_embeddings=False,
is_composition=True,
rms_norm_eps=1e-5,
)
else:
text_config = MistralConfig(
attention_dropout=0.0,
bos_token_id=1,
eos_token_id=2,
head_dim=128,
hidden_act="silu",
hidden_size=5120,
initializer_range=0.02,
intermediate_size=14336,
max_position_embeddings=1024000,
model_type="mistral",
num_attention_heads=32,
num_hidden_layers=40,
num_key_value_heads=8,
rms_norm_eps=1e-05,
rope_theta=1000000000.0,
sliding_window=None,
tie_word_embeddings=False,
vocab_size=131072,
)
adapter_bias = vision_config.pop("adapter_bias", True)
vision_config = PixtralVisionConfig(**vision_config)
config = LlavaConfig( config = LlavaConfig(
vision_config, vision_config,
text_config, text_config,
@@ -247,38 +264,55 @@ def convert_mistral_model(input_dir, output_dir):
image_token_index=10, image_token_index=10,
vision_feature_select_strategy="full", vision_feature_select_strategy="full",
image_seq_length=1, image_seq_length=1,
multimodal_projector_bias=adapter_bias,
) )
config.architectures = ["LlavaForConditionalGeneration"] config.architectures = ["LlavaForConditionalGeneration"]
config.save_pretrained(output_dir) config.save_pretrained(output_dir)
full_original_state_dict = {}
safetensors_files = sorted([file for file in os.listdir(input_dir) if file.endswith(".safetensors")])
if len(safetensors_files) == 1:
full_original_state_dict = safe_load_file(f"{input_dir}/consolidated.safetensors")
else:
for file in safetensors_files:
loaded_dict = safe_load_file(f"{input_dir}/{file}")
full_original_state_dict.update(loaded_dict)
original_state_dict = safe_load_file(f"{input_dir}/consolidated.safetensors") new_dict = convert_dictionary(full_original_state_dict, vision_config, text_config)
new_dict = convert_dictionnary(original_state_dict, vision_config, text_config)
with torch.device("meta"): with torch.device("meta"):
model = LlavaForConditionalGeneration(config) model = LlavaForConditionalGeneration(config)
model.load_state_dict(new_dict, strict=True, assign=True) model.load_state_dict(new_dict, strict=True, assign=True)
model.save_pretrained(output_dir) model.save_pretrained(output_dir)
tokenizer = convert_mistral_tokenizer()
image_processor = PixtralImageProcessor()
processor = PixtralProcessor(tokenizer=tokenizer, image_processor=image_processor, image_token="[IMG]")
processor.save_pretrained(output_dir)
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"--input_dir", "--input_dir",
help="Location of LLaMA weights, which contains tokenizer.model and model folders", help="Location of LLaMA weights, which contains tokenizer.model and model folders",
required=True,
) )
parser.add_argument( parser.add_argument(
"--output_dir", "--output_dir",
help="Location to write HF model and tokenizer", help="Location to write HF model and tokenizer",
required=True,
)
parser.add_argument(
"--tokenizer_file", help="Location of the specific tokenizer model file to use.", required=True
)
parser.add_argument(
"--chat_template_file",
help="Optional file containing a raw chat template. Will be set as the processor's chat template.",
required=False,
) )
args = parser.parse_args() args = parser.parse_args()
convert_mistral_model(args.input_dir, args.output_dir) convert_mistral_model(args.input_dir, args.output_dir)
tokenizer = convert_mistral_tokenizer(args.tokenizer_file)
image_processor = PixtralImageProcessor()
processor = PixtralProcessor(tokenizer=tokenizer, image_processor=image_processor, image_token="[IMG]")
if args.chat_template_file:
processor.chat_template = open(args.chat_template_file).read()
processor.save_pretrained(args.output_dir)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -37,7 +37,7 @@ from ...image_utils import (
validate_kwargs, validate_kwargs,
validate_preprocess_arguments, validate_preprocess_arguments,
) )
from ...utils import TensorType, is_torch_device, is_torch_dtype, is_torch_tensor, is_vision_available, logging from ...utils import TensorType, is_torch_device, is_torch_dtype, is_vision_available, logging
from ...utils.import_utils import requires_backends from ...utils.import_utils import requires_backends
@@ -63,10 +63,24 @@ class BatchMixFeature(BatchFeature):
Returns: Returns:
[`BatchFeature`]: The same instance after modification. [`BatchFeature`]: The same instance after modification.
""" """
def _recursive_to(obj, device, *args, **kwargs):
# Lists can be nested, so keep digging until we hit tensors
if isinstance(obj, list):
return [_recursive_to(o, device, *args, **kwargs) for o in obj]
# We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj):
# cast and send to device
return obj.to(*args, **kwargs)
elif isinstance(obj, torch.Tensor) and device is not None:
# only send to device, don't cast
return obj.to(device=device)
else:
return obj
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
import torch # noqa import torch # noqa
new_data = {}
device = kwargs.get("device") device = kwargs.get("device")
# Check if the args are a device or a dtype # Check if the args are a device or a dtype
if device is None and len(args) > 0: if device is None and len(args) > 0:
@@ -80,21 +94,8 @@ class BatchMixFeature(BatchFeature):
else: else:
# it's something else # it's something else
raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.") raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
# We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
for k, v in self.items(): self.data = {k: _recursive_to(v, device, *args, **kwargs) for k, v in self.data.items()}
# check if v is a floating point
if isinstance(v, list):
new_data[k] = [
element.to(*args, **kwargs) for sample in v for element in sample if is_torch_tensor(element)
]
elif isinstance(v, torch.Tensor) and torch.is_floating_point(v):
# cast and send to device
new_data[k] = v.to(*args, **kwargs)
elif isinstance(v, torch.Tensor) and device is not None:
new_data[k] = v.to(device=device)
else:
new_data[k] = v
self.data = new_data
return self return self

View File

@@ -126,7 +126,6 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1) return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors. """Applies Rotary Position Embedding to the query and key tensors.
@@ -491,18 +490,20 @@ class PixtralVisionModel(PixtralPreTrainedModel):
all tokens of all images of shape (N_toks, D) all tokens of all images of shape (N_toks, D)
""" """
# pass images through initial convolution independently # pass images through initial convolution independently
patch_embeds_list = [self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in pixel_values] if len(pixel_values) > 1:
raise ValueError("Batching/padding not supported yet!")
patch_embeds_list = [self.patch_conv(img.to(self.dtype)) for sample in pixel_values for img in sample]
# flatten to a single sequence # flatten to a single sequence
patch_embeds = torch.cat([p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1) patch_embeds = torch.cat([p.flatten(1).T for p in patch_embeds_list], dim=0).unsqueeze(0)
patch_embeds = self.ln_pre(patch_embeds) patch_embeds = self.ln_pre(patch_embeds)
# positional embeddings # positional embeddings
position_ids = position_ids_in_meshgrid( position_ids = position_ids_in_meshgrid(
patch_embeds_list, max_width=self.config.image_size // self.config.patch_size patch_embeds_list, max_width=self.config.image_size // self.config.patch_size
).to(self.device) ).to(self.device)
position_embedding = self.patch_positional_embedding(patch_embeds, position_ids) position_embedding = self.patch_positional_embedding(patch_embeds, position_ids)
attention_mask = generate_block_attention_mask( attention_mask = generate_block_attention_mask(
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds
) )

View File

@@ -22,7 +22,7 @@ from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image, load_image from ...image_utils import ImageInput, is_valid_image, load_image
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import is_torch_device, is_torch_dtype, is_torch_tensor, logging, requires_backends from ...utils import is_torch_device, is_torch_dtype, logging, requires_backends
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@@ -66,10 +66,24 @@ class BatchMixFeature(BatchFeature):
Returns: Returns:
[`BatchFeature`]: The same instance after modification. [`BatchFeature`]: The same instance after modification.
""" """
def _recursive_to(obj, device, *args, **kwargs):
# Lists can be nested, so keep digging until we hit tensors
if isinstance(obj, list):
return [_recursive_to(o, device, *args, **kwargs) for o in obj]
# We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj):
# cast and send to device
return obj.to(*args, **kwargs)
elif isinstance(obj, torch.Tensor) and device is not None:
# only send to device, don't cast
return obj.to(device=device)
else:
return obj
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
import torch # noqa import torch # noqa
new_data = {}
device = kwargs.get("device") device = kwargs.get("device")
# Check if the args are a device or a dtype # Check if the args are a device or a dtype
if device is None and len(args) > 0: if device is None and len(args) > 0:
@@ -83,21 +97,8 @@ class BatchMixFeature(BatchFeature):
else: else:
# it's something else # it's something else
raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.") raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
# We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
for k, v in self.items(): self.data = {k: _recursive_to(v, device, *args, **kwargs) for k, v in self.data.items()}
# check if v is a floating point
if isinstance(v, list):
new_data[k] = [
element.to(*args, **kwargs) for sample in v for element in sample if is_torch_tensor(element)
]
elif isinstance(v, torch.Tensor) and torch.is_floating_point(v):
# cast and send to device
new_data[k] = v.to(*args, **kwargs)
elif isinstance(v, torch.Tensor) and device is not None:
new_data[k] = v.to(device=device)
else:
new_data[k] = v
self.data = new_data
return self return self
@@ -204,12 +205,21 @@ class PixtralProcessor(ProcessorMixin):
if images is not None: if images is not None:
if is_image_or_image_url(images): if is_image_or_image_url(images):
images = [[images]] if isinstance(text, str) or isinstance(text, list) and len(text) == 1:
elif isinstance(images, list) and is_image_or_image_url(images[0]): # If there's a single sample, the image must belong to it
if isinstance(text, list): images = [[images]]
images = [[im] for im in images]
else: else:
raise ValueError(
"You have supplied multiple text samples, but `images` is not a nested list. When processing multiple samples, `images` should be a list of lists of images, one list per sample."
)
elif isinstance(images, list) and is_image_or_image_url(images[0]):
if isinstance(text, str) or isinstance(text, list) and len(text) == 1:
# If there's a single sample, all images must belong to it
images = [images] images = [images]
else:
raise ValueError(
"You have supplied multiple text samples, but `images` is not a nested list. When processing multiple samples, `images` should be a list of lists of images, one list per sample."
)
elif isinstance(images, list) and isinstance(images[0], list) and is_image_or_image_url(images[0][0]): elif isinstance(images, list) and isinstance(images[0], list) and is_image_or_image_url(images[0][0]):
pass pass
else: else:

View File

@@ -55,6 +55,8 @@ class VideoLlavaConfig(PretrainedConfig):
Sequence length of one image embedding. Sequence length of one image embedding.
video_seq_length (`int`, *optional*, defaults to 2056): video_seq_length (`int`, *optional*, defaults to 2056):
Sequence length of one video embedding. Sequence length of one video embedding.
multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
Whether to use bias in the multimodal projector.
Example: Example:
@@ -92,6 +94,7 @@ class VideoLlavaConfig(PretrainedConfig):
vision_feature_layer=-2, vision_feature_layer=-2,
image_seq_length=256, image_seq_length=256,
video_seq_length=2056, video_seq_length=2056,
multimodal_projector_bias=True,
**kwargs, **kwargs,
): ):
self.ignore_index = ignore_index self.ignore_index = ignore_index
@@ -102,6 +105,7 @@ class VideoLlavaConfig(PretrainedConfig):
self.vision_feature_layer = vision_feature_layer self.vision_feature_layer = vision_feature_layer
self.image_seq_length = image_seq_length self.image_seq_length = image_seq_length
self.video_seq_length = video_seq_length self.video_seq_length = video_seq_length
self.multimodal_projector_bias = multimodal_projector_bias
self.vision_config = vision_config self.vision_config = vision_config

View File

@@ -88,10 +88,13 @@ class VideoLlavaCausalLMOutputWithPast(ModelOutput):
class VideoLlavaMultiModalProjector(nn.Module): class VideoLlavaMultiModalProjector(nn.Module):
def __init__(self, config: VideoLlavaConfig): def __init__(self, config: VideoLlavaConfig):
super().__init__() super().__init__()
self.linear_1 = nn.Linear(
self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True) config.vision_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
)
self.act = ACT2FN[config.projector_hidden_act] self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True) self.linear_2 = nn.Linear(
config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
)
def forward(self, image_features): def forward(self, image_features):
hidden_states = self.linear_1(image_features) hidden_states = self.linear_1(image_features)

View File

@@ -253,7 +253,7 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
"USER: [IMG]\nWhat's the content of the image? ASSISTANT:", "USER: [IMG]\nWhat's the content of the image? ASSISTANT:",
] * 5 ] * 5
processor.tokenizer.pad_token = "</s>" processor.tokenizer.pad_token = "</s>"
image_inputs = [self.image_0] * 5 image_inputs = [[self.image_0]] * 5
# Make small for checking image token expansion # Make small for checking image token expansion
processor.image_processor.size = {"longest_edge": 30} processor.image_processor.size = {"longest_edge": 30}