From 3f483beab9076705cf3a900c20837e7555303c3d Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Wed, 8 Jan 2025 17:39:47 +0100 Subject: [PATCH] [`PixtralLarge`] Update Pixtral conversion script to support large format! (#34801) * update conversion script * update for bias again * remove pdv * use my dir * Update how we initialize the tokenizer * Convert in bfloat16 * Undo that one again * fix config dump * .to() was broken for BatchMixFeature * quick debug breakpoint * put the breakpoint in the right place * Add a config flag for the multimodal projector bias * Add a config flag for the multimodal projector bias * Conversion script can load chat templates * Indent config for comparison * Stop clobbering the config * Re-enable the config clobber * Get rid of the config manual save - it has no effect! * Handle adapter bias correctly * Default vision transformer activation to silu * Remove legacy processing path * One commit with all the debug breakpoints before I delete them all, in case I need to revert * Update conversion * Remove vLLM debugging instrumentation * Drop xformers * Remove debug enumerates * make fixup * make fixup * Break copied from in pixtral * Propagate multimodal_projector_bias change * Propagate multimodal_projector_bias change * Remove debug device .to() * Restore attention weights output * Fix Pixtral test * Drop image_seq_length * Drop image_seq_length * Put the legacy processing code back * Add the bias option to the llava_next_video config * Add the bias option to the llava_next_video config * Make certain args required in converter * Make certain args required in converter * typo * make fixup * Reverting some dtype changes since it seems to work without them --------- Co-authored-by: arthur@huggingface.co Co-authored-by: Matt Co-authored-by: Matt --- .../models/llava/configuration_llava.py | 4 + .../models/llava/modeling_llava.py | 9 +- .../llava_next/configuration_llava_next.py | 4 + .../models/llava_next/modeling_llava_next.py | 9 +- .../configuration_llava_next_video.py | 4 + .../modeling_llava_next_video.py | 9 +- .../modular_llava_next_video.py | 4 + .../configuration_llava_onevision.py | 4 + .../modeling_llava_onevision.py | 9 +- .../pixtral/convert_pixtral_weights_to_hf.py | 146 +++++++++++------- .../pixtral/image_processing_pixtral.py | 35 +++-- .../models/pixtral/modeling_pixtral.py | 9 +- .../models/pixtral/processing_pixtral.py | 52 ++++--- .../video_llava/configuration_video_llava.py | 4 + .../video_llava/modeling_video_llava.py | 9 +- .../models/pixtral/test_processor_pixtral.py | 2 +- 16 files changed, 199 insertions(+), 114 deletions(-) diff --git a/src/transformers/models/llava/configuration_llava.py b/src/transformers/models/llava/configuration_llava.py index 68ec84b4d3..58bf40d6ce 100644 --- a/src/transformers/models/llava/configuration_llava.py +++ b/src/transformers/models/llava/configuration_llava.py @@ -50,6 +50,8 @@ class LlavaConfig(PretrainedConfig): The index of the layer to select the vision feature. image_seq_length (`int`, *optional*, defaults to 576): Sequence length of one image embedding. + multimodal_projector_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the multimodal projector. Example: @@ -85,6 +87,7 @@ class LlavaConfig(PretrainedConfig): vision_feature_select_strategy="default", vision_feature_layer=-2, image_seq_length=576, + multimodal_projector_bias=True, **kwargs, ): self.ignore_index = ignore_index @@ -127,6 +130,7 @@ class LlavaConfig(PretrainedConfig): text_config = CONFIG_MAPPING["llama"]() self.text_config = text_config + self.multimodal_projector_bias = multimodal_projector_bias super().__init__(**kwargs) diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 09c63e39b3..3d9bc339fd 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -86,10 +86,13 @@ class LlavaCausalLMOutputWithPast(ModelOutput): class LlavaMultiModalProjector(nn.Module): def __init__(self, config: LlavaConfig): super().__init__() - - self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True) + self.linear_1 = nn.Linear( + config.vision_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias + ) self.act = ACT2FN[config.projector_hidden_act] - self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True) + self.linear_2 = nn.Linear( + config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias + ) def forward(self, image_features): hidden_states = self.linear_1(image_features) diff --git a/src/transformers/models/llava_next/configuration_llava_next.py b/src/transformers/models/llava_next/configuration_llava_next.py index 2251a330aa..6cb76c5b9d 100644 --- a/src/transformers/models/llava_next/configuration_llava_next.py +++ b/src/transformers/models/llava_next/configuration_llava_next.py @@ -55,6 +55,8 @@ class LlavaNextConfig(PretrainedConfig): Whether the model's input and output word embeddings should be tied. image_seq_length (`int`, *optional*, defaults to 576): Sequence length of one image embedding. + multimodal_projector_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the multimodal projector. Example: @@ -92,12 +94,14 @@ class LlavaNextConfig(PretrainedConfig): image_grid_pinpoints=None, tie_word_embeddings=False, image_seq_length=576, + multimodal_projector_bias=True, **kwargs, ): self.ignore_index = ignore_index self.image_token_index = image_token_index self.projector_hidden_act = projector_hidden_act self.image_seq_length = image_seq_length + self.multimodal_projector_bias = multimodal_projector_bias if vision_feature_select_strategy not in ["default", "full"]: raise ValueError( diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 1a1223e9c2..71e4638989 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -194,10 +194,13 @@ class LlavaNextCausalLMOutputWithPast(ModelOutput): class LlavaNextMultiModalProjector(nn.Module): def __init__(self, config: LlavaNextConfig): super().__init__() - - self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True) + self.linear_1 = nn.Linear( + config.vision_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias + ) self.act = ACT2FN[config.projector_hidden_act] - self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True) + self.linear_2 = nn.Linear( + config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias + ) def forward(self, image_features): hidden_states = self.linear_1(image_features) diff --git a/src/transformers/models/llava_next_video/configuration_llava_next_video.py b/src/transformers/models/llava_next_video/configuration_llava_next_video.py index e608e5a0d2..77089ed0f3 100644 --- a/src/transformers/models/llava_next_video/configuration_llava_next_video.py +++ b/src/transformers/models/llava_next_video/configuration_llava_next_video.py @@ -44,6 +44,8 @@ class LlavaNextVideoConfig(PretrainedConfig): The image token index to encode the image prompt. projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): The activation function used by the multimodal projector. + multimodal_projector_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the multimodal projector. vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): The feature selection strategy used to select the vision feature from the vision backbone. Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features. @@ -95,6 +97,7 @@ class LlavaNextVideoConfig(PretrainedConfig): ignore_index=-100, image_token_index=32001, projector_hidden_act="gelu", + multimodal_projector_bias=True, vision_feature_select_strategy="default", vision_feature_layer=-2, image_grid_pinpoints=None, @@ -114,6 +117,7 @@ class LlavaNextVideoConfig(PretrainedConfig): self.ignore_index = ignore_index self.image_token_index = image_token_index self.projector_hidden_act = projector_hidden_act + self.multimodal_projector_bias = multimodal_projector_bias if vision_feature_select_strategy not in ["default", "full"]: raise ValueError( diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index f6a66a7a9b..b1ae26aaac 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -179,10 +179,13 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel): class LlavaNextVideoMultiModalProjector(nn.Module): def __init__(self, config: LlavaNextVideoConfig): super().__init__() - - self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True) + self.linear_1 = nn.Linear( + config.vision_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias + ) self.act = ACT2FN[config.projector_hidden_act] - self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True) + self.linear_2 = nn.Linear( + config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias + ) def forward(self, image_features): hidden_states = self.linear_1(image_features) diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py index 5c04c96b88..89975a745b 100644 --- a/src/transformers/models/llava_next_video/modular_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py @@ -58,6 +58,8 @@ class LlavaNextVideoConfig(PretrainedConfig): The image token index to encode the image prompt. projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): The activation function used by the multimodal projector. + multimodal_projector_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the multimodal projector. vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): The feature selection strategy used to select the vision feature from the vision backbone. Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features. @@ -109,6 +111,7 @@ class LlavaNextVideoConfig(PretrainedConfig): ignore_index=-100, image_token_index=32001, projector_hidden_act="gelu", + multimodal_projector_bias=True, vision_feature_select_strategy="default", vision_feature_layer=-2, image_grid_pinpoints=None, @@ -128,6 +131,7 @@ class LlavaNextVideoConfig(PretrainedConfig): self.ignore_index = ignore_index self.image_token_index = image_token_index self.projector_hidden_act = projector_hidden_act + self.multimodal_projector_bias = multimodal_projector_bias if vision_feature_select_strategy not in ["default", "full"]: raise ValueError( diff --git a/src/transformers/models/llava_onevision/configuration_llava_onevision.py b/src/transformers/models/llava_onevision/configuration_llava_onevision.py index 74be035a8f..504e8a7878 100644 --- a/src/transformers/models/llava_onevision/configuration_llava_onevision.py +++ b/src/transformers/models/llava_onevision/configuration_llava_onevision.py @@ -58,6 +58,8 @@ class LlavaOnevisionConfig(PretrainedConfig): of the form `(height, width)`. tie_word_embeddings (`bool`, *optional*, defaults to `False`): Whether the model's input and output word embeddings should be tied. + multimodal_projector_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the multimodal projector. Example: @@ -95,11 +97,13 @@ class LlavaOnevisionConfig(PretrainedConfig): vision_aspect_ratio="anyres_max_9", image_grid_pinpoints=None, tie_word_embeddings=False, + multimodal_projector_bias=True, **kwargs, ): self.image_token_index = image_token_index self.video_token_index = video_token_index self.projector_hidden_act = projector_hidden_act + self.multimodal_projector_bias = multimodal_projector_bias if vision_feature_select_strategy not in ["default", "full"]: raise ValueError( diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index 4bcdf1aba8..7bc88ec95a 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -201,10 +201,13 @@ class LlavaOnevisionCausalLMOutputWithPast(ModelOutput): class LlavaOnevisionMultiModalProjector(nn.Module): def __init__(self, config: LlavaOnevisionConfig): super().__init__() - - self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True) + self.linear_1 = nn.Linear( + config.vision_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias + ) self.act = ACT2FN[config.projector_hidden_act] - self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True) + self.linear_2 = nn.Linear( + config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias + ) def forward(self, image_features): hidden_states = self.linear_1(image_features) diff --git a/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py b/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py index c4190082d9..a8b3ae5024 100644 --- a/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py +++ b/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse +import json +import os import regex as re import torch @@ -27,7 +29,6 @@ from transformers import ( PixtralImageProcessor, PixtralProcessor, PixtralVisionConfig, - PreTrainedTokenizerFast, ) from transformers.convert_slow_tokenizer import bytes_to_unicode @@ -156,29 +157,18 @@ class MistralConverter: return tokenizer -def convert_mistral_tokenizer(): - model_name = "mistralai/Pixtral-12B-2409" +def convert_mistral_tokenizer(model_file): + from transformers import LlamaTokenizer - tokenizer = MistralTokenizer.from_model(model_name) - - vocab = tokenizer.instruct_tokenizer.tokenizer._tekken_token2id_nospecial - all_special = [ - token.value if hasattr(token, "value") else token - for token in tokenizer.instruct_tokenizer.tokenizer._all_special_tokens - ] - specials_tokens = {token: all_special.index(token) for token in all_special} - specials_tokens.update(vocab) - vocab = specials_tokens - - tokenizer = PreTrainedTokenizerFast( - tokenizer_object=MistralConverter(vocab=vocab, additional_special_tokens=all_special).converted(), - bos_token="", - unk_token="", - eos_token="", - ) - tokenizer.model_input_names = ["input_ids", "attention_mask"] - - return tokenizer + mistral_tokenizer = MistralTokenizer.from_file(model_file) + vocab = mistral_tokenizer.instruct_tokenizer.tokenizer.vocab() + control_token_ids = mistral_tokenizer.instruct_tokenizer.tokenizer._control_tokens + all_special = [vocab[id] for id in control_token_ids] + hf_tokenizer = LlamaTokenizer(model_file) + # Do I need to exclude tokens that are already special? + hf_tokenizer.add_special_tokens({"additional_special_tokens": all_special}) + hf_tokenizer.model_input_names = ["input_ids", "attention_mask"] + return hf_tokenizer def permute_for_rope(value, n_heads, config): @@ -187,7 +177,7 @@ def permute_for_rope(value, n_heads, config): return value.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) -def convert_dictionnary(original_state_dict, vision_config, text_config): +def convert_dictionary(original_state_dict, vision_config, text_config): new_dict = {} all_keys = "\n" + "\n".join(original_state_dict.keys()) @@ -208,7 +198,6 @@ def convert_dictionnary(original_state_dict, vision_config, text_config): num_attention_heads = _config.num_attention_heads if "k_proj" in new_key: num_attention_heads = _config.num_key_value_heads - # convert the text model (basically mistral model) if "q_proj" in new_key or "k_proj" in new_key: value = permute_for_rope(value, num_attention_heads, _config) @@ -217,29 +206,57 @@ def convert_dictionnary(original_state_dict, vision_config, text_config): return new_dict -def convert_mistral_model(input_dir, output_dir): - text_config = MistralConfig( - attention_dropout=0.0, - bos_token_id=1, - eos_token_id=2, - head_dim=128, - hidden_act="silu", - hidden_size=5120, - initializer_range=0.02, - intermediate_size=14336, - max_position_embeddings=1024000, - model_type="mistral", - num_attention_heads=32, - num_hidden_layers=40, - num_key_value_heads=8, - rms_norm_eps=1e-05, - rope_theta=1000000000.0, - sliding_window=None, - tie_word_embeddings=False, - vocab_size=131072, - ) +MISTRAL_CONFIG_MAPPING = { + "dim": "hidden_size", + "hidden_dim": "intermediate_size", + "n_kv_heads": "num_key_value_heads", + "n_heads": "num_attention_heads", + "n_layers": "num_hidden_layers", +} - vision_config = PixtralVisionConfig() + +def convert_mistral_model(input_dir, output_dir): + vision_config = {} + if os.path.isfile(f"{input_dir}/params.json"): + with open(f"{input_dir}/params.json") as f: + param_json = json.load(f) + vision_config = param_json.pop("vision_encoder") + for k, v in MISTRAL_CONFIG_MAPPING.items(): + value = param_json.pop(k) + param_json[v] = value + if "hidden_act" not in vision_config: + vision_config["hidden_act"] = "silu" + text_config = MistralConfig( + **param_json, + hidden_act="silu", + sliding_window=None, + tie_word_embeddings=False, + is_composition=True, + rms_norm_eps=1e-5, + ) + else: + text_config = MistralConfig( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + head_dim=128, + hidden_act="silu", + hidden_size=5120, + initializer_range=0.02, + intermediate_size=14336, + max_position_embeddings=1024000, + model_type="mistral", + num_attention_heads=32, + num_hidden_layers=40, + num_key_value_heads=8, + rms_norm_eps=1e-05, + rope_theta=1000000000.0, + sliding_window=None, + tie_word_embeddings=False, + vocab_size=131072, + ) + adapter_bias = vision_config.pop("adapter_bias", True) + vision_config = PixtralVisionConfig(**vision_config) config = LlavaConfig( vision_config, text_config, @@ -247,38 +264,55 @@ def convert_mistral_model(input_dir, output_dir): image_token_index=10, vision_feature_select_strategy="full", image_seq_length=1, + multimodal_projector_bias=adapter_bias, ) config.architectures = ["LlavaForConditionalGeneration"] config.save_pretrained(output_dir) + full_original_state_dict = {} + safetensors_files = sorted([file for file in os.listdir(input_dir) if file.endswith(".safetensors")]) + if len(safetensors_files) == 1: + full_original_state_dict = safe_load_file(f"{input_dir}/consolidated.safetensors") + else: + for file in safetensors_files: + loaded_dict = safe_load_file(f"{input_dir}/{file}") + full_original_state_dict.update(loaded_dict) - original_state_dict = safe_load_file(f"{input_dir}/consolidated.safetensors") - new_dict = convert_dictionnary(original_state_dict, vision_config, text_config) - + new_dict = convert_dictionary(full_original_state_dict, vision_config, text_config) with torch.device("meta"): model = LlavaForConditionalGeneration(config) model.load_state_dict(new_dict, strict=True, assign=True) - model.save_pretrained(output_dir) - tokenizer = convert_mistral_tokenizer() - image_processor = PixtralImageProcessor() - processor = PixtralProcessor(tokenizer=tokenizer, image_processor=image_processor, image_token="[IMG]") - processor.save_pretrained(output_dir) - def main(): parser = argparse.ArgumentParser() parser.add_argument( "--input_dir", help="Location of LLaMA weights, which contains tokenizer.model and model folders", + required=True, ) parser.add_argument( "--output_dir", help="Location to write HF model and tokenizer", + required=True, + ) + parser.add_argument( + "--tokenizer_file", help="Location of the specific tokenizer model file to use.", required=True + ) + parser.add_argument( + "--chat_template_file", + help="Optional file containing a raw chat template. Will be set as the processor's chat template.", + required=False, ) args = parser.parse_args() convert_mistral_model(args.input_dir, args.output_dir) + tokenizer = convert_mistral_tokenizer(args.tokenizer_file) + image_processor = PixtralImageProcessor() + processor = PixtralProcessor(tokenizer=tokenizer, image_processor=image_processor, image_token="[IMG]") + if args.chat_template_file: + processor.chat_template = open(args.chat_template_file).read() + processor.save_pretrained(args.output_dir) if __name__ == "__main__": diff --git a/src/transformers/models/pixtral/image_processing_pixtral.py b/src/transformers/models/pixtral/image_processing_pixtral.py index 05299949a6..6d83e0c464 100644 --- a/src/transformers/models/pixtral/image_processing_pixtral.py +++ b/src/transformers/models/pixtral/image_processing_pixtral.py @@ -37,7 +37,7 @@ from ...image_utils import ( validate_kwargs, validate_preprocess_arguments, ) -from ...utils import TensorType, is_torch_device, is_torch_dtype, is_torch_tensor, is_vision_available, logging +from ...utils import TensorType, is_torch_device, is_torch_dtype, is_vision_available, logging from ...utils.import_utils import requires_backends @@ -63,10 +63,24 @@ class BatchMixFeature(BatchFeature): Returns: [`BatchFeature`]: The same instance after modification. """ + + def _recursive_to(obj, device, *args, **kwargs): + # Lists can be nested, so keep digging until we hit tensors + if isinstance(obj, list): + return [_recursive_to(o, device, *args, **kwargs) for o in obj] + # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor` + elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj): + # cast and send to device + return obj.to(*args, **kwargs) + elif isinstance(obj, torch.Tensor) and device is not None: + # only send to device, don't cast + return obj.to(device=device) + else: + return obj + requires_backends(self, ["torch"]) import torch # noqa - new_data = {} device = kwargs.get("device") # Check if the args are a device or a dtype if device is None and len(args) > 0: @@ -80,21 +94,8 @@ class BatchMixFeature(BatchFeature): else: # it's something else raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.") - # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor` - for k, v in self.items(): - # check if v is a floating point - if isinstance(v, list): - new_data[k] = [ - element.to(*args, **kwargs) for sample in v for element in sample if is_torch_tensor(element) - ] - elif isinstance(v, torch.Tensor) and torch.is_floating_point(v): - # cast and send to device - new_data[k] = v.to(*args, **kwargs) - elif isinstance(v, torch.Tensor) and device is not None: - new_data[k] = v.to(device=device) - else: - new_data[k] = v - self.data = new_data + + self.data = {k: _recursive_to(v, device, *args, **kwargs) for k, v in self.data.items()} return self diff --git a/src/transformers/models/pixtral/modeling_pixtral.py b/src/transformers/models/pixtral/modeling_pixtral.py index 0b7d2dfdd8..905eef22ca 100644 --- a/src/transformers/models/pixtral/modeling_pixtral.py +++ b/src/transformers/models/pixtral/modeling_pixtral.py @@ -126,7 +126,6 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -491,18 +490,20 @@ class PixtralVisionModel(PixtralPreTrainedModel): all tokens of all images of shape (N_toks, D) """ # pass images through initial convolution independently - patch_embeds_list = [self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in pixel_values] + if len(pixel_values) > 1: + raise ValueError("Batching/padding not supported yet!") + patch_embeds_list = [self.patch_conv(img.to(self.dtype)) for sample in pixel_values for img in sample] # flatten to a single sequence - patch_embeds = torch.cat([p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1) + patch_embeds = torch.cat([p.flatten(1).T for p in patch_embeds_list], dim=0).unsqueeze(0) patch_embeds = self.ln_pre(patch_embeds) - # positional embeddings position_ids = position_ids_in_meshgrid( patch_embeds_list, max_width=self.config.image_size // self.config.patch_size ).to(self.device) position_embedding = self.patch_positional_embedding(patch_embeds, position_ids) + attention_mask = generate_block_attention_mask( [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds ) diff --git a/src/transformers/models/pixtral/processing_pixtral.py b/src/transformers/models/pixtral/processing_pixtral.py index 0b5354daad..e60151130a 100644 --- a/src/transformers/models/pixtral/processing_pixtral.py +++ b/src/transformers/models/pixtral/processing_pixtral.py @@ -22,7 +22,7 @@ from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput, is_valid_image, load_image from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order from ...tokenization_utils_base import PreTokenizedInput, TextInput -from ...utils import is_torch_device, is_torch_dtype, is_torch_tensor, logging, requires_backends +from ...utils import is_torch_device, is_torch_dtype, logging, requires_backends logger = logging.get_logger(__name__) @@ -66,10 +66,24 @@ class BatchMixFeature(BatchFeature): Returns: [`BatchFeature`]: The same instance after modification. """ + + def _recursive_to(obj, device, *args, **kwargs): + # Lists can be nested, so keep digging until we hit tensors + if isinstance(obj, list): + return [_recursive_to(o, device, *args, **kwargs) for o in obj] + # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor` + elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj): + # cast and send to device + return obj.to(*args, **kwargs) + elif isinstance(obj, torch.Tensor) and device is not None: + # only send to device, don't cast + return obj.to(device=device) + else: + return obj + requires_backends(self, ["torch"]) import torch # noqa - new_data = {} device = kwargs.get("device") # Check if the args are a device or a dtype if device is None and len(args) > 0: @@ -83,21 +97,8 @@ class BatchMixFeature(BatchFeature): else: # it's something else raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.") - # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor` - for k, v in self.items(): - # check if v is a floating point - if isinstance(v, list): - new_data[k] = [ - element.to(*args, **kwargs) for sample in v for element in sample if is_torch_tensor(element) - ] - elif isinstance(v, torch.Tensor) and torch.is_floating_point(v): - # cast and send to device - new_data[k] = v.to(*args, **kwargs) - elif isinstance(v, torch.Tensor) and device is not None: - new_data[k] = v.to(device=device) - else: - new_data[k] = v - self.data = new_data + + self.data = {k: _recursive_to(v, device, *args, **kwargs) for k, v in self.data.items()} return self @@ -204,12 +205,21 @@ class PixtralProcessor(ProcessorMixin): if images is not None: if is_image_or_image_url(images): - images = [[images]] - elif isinstance(images, list) and is_image_or_image_url(images[0]): - if isinstance(text, list): - images = [[im] for im in images] + if isinstance(text, str) or isinstance(text, list) and len(text) == 1: + # If there's a single sample, the image must belong to it + images = [[images]] else: + raise ValueError( + "You have supplied multiple text samples, but `images` is not a nested list. When processing multiple samples, `images` should be a list of lists of images, one list per sample." + ) + elif isinstance(images, list) and is_image_or_image_url(images[0]): + if isinstance(text, str) or isinstance(text, list) and len(text) == 1: + # If there's a single sample, all images must belong to it images = [images] + else: + raise ValueError( + "You have supplied multiple text samples, but `images` is not a nested list. When processing multiple samples, `images` should be a list of lists of images, one list per sample." + ) elif isinstance(images, list) and isinstance(images[0], list) and is_image_or_image_url(images[0][0]): pass else: diff --git a/src/transformers/models/video_llava/configuration_video_llava.py b/src/transformers/models/video_llava/configuration_video_llava.py index 6a77b28339..2342e16da4 100644 --- a/src/transformers/models/video_llava/configuration_video_llava.py +++ b/src/transformers/models/video_llava/configuration_video_llava.py @@ -55,6 +55,8 @@ class VideoLlavaConfig(PretrainedConfig): Sequence length of one image embedding. video_seq_length (`int`, *optional*, defaults to 2056): Sequence length of one video embedding. + multimodal_projector_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the multimodal projector. Example: @@ -92,6 +94,7 @@ class VideoLlavaConfig(PretrainedConfig): vision_feature_layer=-2, image_seq_length=256, video_seq_length=2056, + multimodal_projector_bias=True, **kwargs, ): self.ignore_index = ignore_index @@ -102,6 +105,7 @@ class VideoLlavaConfig(PretrainedConfig): self.vision_feature_layer = vision_feature_layer self.image_seq_length = image_seq_length self.video_seq_length = video_seq_length + self.multimodal_projector_bias = multimodal_projector_bias self.vision_config = vision_config diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 80dfaa2f0e..aeff4ad1d0 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -88,10 +88,13 @@ class VideoLlavaCausalLMOutputWithPast(ModelOutput): class VideoLlavaMultiModalProjector(nn.Module): def __init__(self, config: VideoLlavaConfig): super().__init__() - - self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True) + self.linear_1 = nn.Linear( + config.vision_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias + ) self.act = ACT2FN[config.projector_hidden_act] - self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True) + self.linear_2 = nn.Linear( + config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias + ) def forward(self, image_features): hidden_states = self.linear_1(image_features) diff --git a/tests/models/pixtral/test_processor_pixtral.py b/tests/models/pixtral/test_processor_pixtral.py index c3496dff3c..d224c53124 100644 --- a/tests/models/pixtral/test_processor_pixtral.py +++ b/tests/models/pixtral/test_processor_pixtral.py @@ -253,7 +253,7 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase): "USER: [IMG]\nWhat's the content of the image? ASSISTANT:", ] * 5 processor.tokenizer.pad_token = "" - image_inputs = [self.image_0] * 5 + image_inputs = [[self.image_0]] * 5 # Make small for checking image token expansion processor.image_processor.size = {"longest_edge": 30}