From 4303d88c097d39e138f47a7946e46943d99bdfdf Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 25 Mar 2025 09:55:21 +0100 Subject: [PATCH] Add Phi4 multimodal (#36939) * raw start * update * update * add to imports * update * up * simplify configs * clean configs * style * typos * Update convert_phi4_multimodal_weights_to_hf.py * Update convert_phi4_multimodal_weights_to_hf.py * fix * up * up * up * Update convert_phi4_multimodal_weights_to_hf.py * Update convert_phi4_multimodal_weights_to_hf.py * up * up * up * Update feature_extraction_phi4_multimodal.py * up * up * up * up * up * simplify configs * typo * cut code * typo * typo * typo * re * typo * up * up * up * add tests * fix * fix * Update test_modeling_phi4_multimodal.py * up * Update test_modeling_phi4_multimodal.py * doc * fix * up * up * up * up * up * up * simplify * up * simplify * config docstrings * cleanup * clean * typo * typo * fix * Update phi4_multimodal.md * fix * fix * Update test_modeling_phi4_multimodal.py * update * simplify reshapes and permutes * up * simplify special tokens * simplify processor a lot * Update processing_phi4_multimodal.py * Update processing_phi4_multimodal.py * switch to fast processor * image processor * Update image_processing_phi4_multimodal_fast.py * add lora extraction to converter * Update convert_phi4_multimodal_weights_to_hf.py * Update __init__.py * add AudioInput type in audio_utils * rewrite feature_extraction: support torch batched FFT * input_audio_embeds -> audio_input_features, input_image_embeds -> image_pixel_values * test update * not mono channel warning update * remove auto maps from processor * kargs dispatch in processor * simplify kwargs dispatch * simplify merging * remove default sampling rate * style * Update test_modeling_phi4_multimodal.py * update doc * doc * torch only feature extractor * make fake tokens adjustable * Update feature_extraction_phi4_multimodal.py * fix * Update processing_phi4_multimodal.py * simplify mask * last touch * fix copies * style * Update audio_utils.py * style * Update feature_extraction_phi4_multimodal.py * Update __init__.py * docstrings * copies * fix all checks * back to fix-copies * trigger CIs * Update feature_extraction_phi4_multimodal.py * improve tests with multimodal inputs * trigger CIs --------- Co-authored-by: Eustache Le Bihan --- docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/phi4_multimodal.md | 149 ++ src/transformers/__init__.py | 36 + src/transformers/audio_utils.py | 7 +- src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + .../models/auto/feature_extraction_auto.py | 1 + .../models/auto/image_processing_auto.py | 1 + src/transformers/models/auto/modeling_auto.py | 2 + .../models/auto/processing_auto.py | 1 + .../models/phi4_multimodal/__init__.py | 32 + .../configuration_phi4_multimodal.py | 482 ++++ .../convert_phi4_multimodal_weights_to_hf.py | 229 ++ .../feature_extraction_phi4_multimodal.py | 348 +++ .../image_processing_phi4_multimodal_fast.py | 263 ++ .../modeling_phi4_multimodal.py | 2316 +++++++++++++++++ .../modular_phi4_multimodal.py | 1851 +++++++++++++ .../processing_phi4_multimodal.py | 194 ++ src/transformers/utils/dummy_pt_objects.py | 49 + .../utils/dummy_torchvision_objects.py | 7 + tests/models/phi4_multimodal/__init__.py | 0 .../test_modeling_phi4_multimodal.py | 405 +++ utils/check_docstrings.py | 1 + utils/check_repo.py | 2 + 24 files changed, 6380 insertions(+), 1 deletion(-) create mode 100644 docs/source/en/model_doc/phi4_multimodal.md create mode 100644 src/transformers/models/phi4_multimodal/__init__.py create mode 100644 src/transformers/models/phi4_multimodal/configuration_phi4_multimodal.py create mode 100644 src/transformers/models/phi4_multimodal/convert_phi4_multimodal_weights_to_hf.py create mode 100644 src/transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py create mode 100644 src/transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py create mode 100644 src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py create mode 100644 src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py create mode 100644 src/transformers/models/phi4_multimodal/processing_phi4_multimodal.py create mode 100644 tests/models/phi4_multimodal/__init__.py create mode 100644 tests/models/phi4_multimodal/test_modeling_phi4_multimodal.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 00d898e4d1..bcd054113c 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -583,6 +583,8 @@ title: Phi - local: model_doc/phi3 title: Phi-3 + - local: model_doc/phi4_multimodal + title: Phi4 Multimodal - local: model_doc/phimoe title: PhiMoE - local: model_doc/phobert diff --git a/docs/source/en/model_doc/phi4_multimodal.md b/docs/source/en/model_doc/phi4_multimodal.md new file mode 100644 index 0000000000..f0d8bb3b46 --- /dev/null +++ b/docs/source/en/model_doc/phi4_multimodal.md @@ -0,0 +1,149 @@ + + +# Phi4 Multimodal + +## Overview + +Phi4 Multimodal is a lightweight open multimodal foundation model that leverages the language, vision, and speech research and datasets used for Phi-3.5 and 4.0 models. The model processes text, image, and audio inputs, generating text outputs, and comes with 128K token context length. The model underwent an enhancement process, incorporating both supervised fine-tuning, direct preference optimization and RLHF (Reinforcement Learning from Human Feedback) to support precise instruction adherence and safety measures. The languages that each modal supports are the following: + +- Text: Arabic, Chinese, Czech, Danish, Dutch, English, Finnish, French, German, Hebrew, Hungarian, Italian, Japanese, Korean, Norwegian, Polish, Portuguese, Russian, Spanish, Swedish, Thai, Turkish, Ukrainian +- Vision: English +- Audio: English, Chinese, German, French, Italian, Japanese, Spanish, Portuguese + +This model was contributed by [Cyril Vallez](https://huggingface.co/cyrilvallez). The most recent code can be +found [here](https://github.com/huggingface/transformers/blob/main/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py). + + +## Usage tips + +`Phi4-multimodal-instruct` can be found on the [Huggingface Hub](https://huggingface.co/microsoft/Phi-4-multimodal-instruct) + +In the following, we demonstrate how to use it for inference depending on the input modalities (text, image, audio). + +```python +import requests +import torch +import os +import io +from PIL import Image +import soundfile as sf +from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig +from urllib.request import urlopen + + +# Define model path +model_path = "microsoft/Phi-4-multimodal-instruct" +device = "cuda:0" + +# Load model and processor +processor = AutoProcessor.from_pretrained(model_path) +model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device, torch_dtype=torch.float16) + +# Optional: load the adapters (note that without them, the base model will very likely not work well) +model.load_adapter(model_path, adapter_name="speech", device_map=device, adapter_kwargs={"subfolder": 'speech-lora'}) +model.load_adapter(model_path, adapter_name="vision", device_map=device, adapter_kwargs={"subfolder": 'vision-lora'}) + +# Define prompt structure +user_prompt = '<|user|>' +assistant_prompt = '<|assistant|>' +prompt_suffix = '<|end|>' + +# Part 1: Image Processing +model.set_adapter("vision") # if loaded, activate the vision adapter +print("\n--- IMAGE PROCESSING ---") +image_url = 'https://www.ilankelman.org/stopsigns/australia.jpg' +prompt = f'{user_prompt}<|image_1|>What is shown in this image?{prompt_suffix}{assistant_prompt}' +print(f'>>> Prompt\n{prompt}') + +# Download and open image +image = Image.open(requests.get(image_url, stream=True).raw) +inputs = processor(text=prompt, images=image, return_tensors='pt').to(device) + +# Generate response +generate_ids = model.generate( + **inputs, + max_new_tokens=1000, + do_sample=False, +) +generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:] +response = processor.batch_decode( + generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False +)[0] +print(f'>>> Response\n{response}') + +# Part 2: Audio Processing +model.set_adapter("speech") # if loaded, activate the speech adapter +print("\n--- AUDIO PROCESSING ---") +audio_url = "https://upload.wikimedia.org/wikipedia/commons/b/b0/Barbara_Sahakian_BBC_Radio4_The_Life_Scientific_29_May_2012_b01j5j24.flac" +speech_prompt = "Transcribe the audio to text, and then translate the audio to French. Use as a separator between the original transcript and the translation." +prompt = f'{user_prompt}<|audio_1|>{speech_prompt}{prompt_suffix}{assistant_prompt}' +print(f'>>> Prompt\n{prompt}') + +# Downlowd and open audio file +audio, sample_rate = sf.read(io.BytesIO(urlopen(audio_url).read())) + +# Process with the model +inputs = processor(text=prompt, audios=audio, sample_rate=sample_rate, return_tensors='pt').to(device) + +generate_ids = model.generate( + **inputs, + max_new_tokens=1000, + do_sample=False, +) +generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:] +response = processor.batch_decode( + generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False +)[0] +print(f'>>> Response\n{response}') +``` + +## Phi4MultimodalFeatureExtractor + +[[autodoc]] Phi4MultimodalFeatureExtractor + +## Phi4MultimodalImageProcessorFast + +[[autodoc]] Phi4MultimodalImageProcessorFast + +## Phi4MultimodalProcessor + +[[autodoc]] Phi4MultimodalProcessor + +## Phi4MultimodalAudioConfig + +[[autodoc]] Phi4MultimodalAudioConfig + +## Phi4MultimodalVisionConfig + +[[autodoc]] Phi4MultimodalVisionConfig + +## Phi4MultimodalConfig + +[[autodoc]] Phi4MultimodalConfig + +## Phi4MultimodalAudioModel + +[[autodoc]] Phi4MultimodalAudioModel + +## Phi4MultimodalVisionModel + +[[autodoc]] Phi4MultimodalVisionModel + +## Phi4MultimodalModel + +[[autodoc]] Phi4MultimodalModel + - forward + +## Phi4MultimodalForCausalLM + +[[autodoc]] Phi4MultimodalForCausalLM + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index e5caa17d25..e8da536747 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -699,6 +699,13 @@ _import_structure = { "models.persimmon": ["PersimmonConfig"], "models.phi": ["PhiConfig"], "models.phi3": ["Phi3Config"], + "models.phi4_multimodal": [ + "Phi4MultimodalAudioConfig", + "Phi4MultimodalConfig", + "Phi4MultimodalFeatureExtractor", + "Phi4MultimodalProcessor", + "Phi4MultimodalVisionConfig", + ], "models.phimoe": ["PhimoeConfig"], "models.phobert": ["PhobertTokenizer"], "models.pix2struct": [ @@ -1348,6 +1355,7 @@ else: _import_structure["models.llava"].append("LlavaImageProcessorFast") _import_structure["models.llava_next"].append("LlavaNextImageProcessorFast") _import_structure["models.llava_onevision"].append("LlavaOnevisionImageProcessorFast") + _import_structure["models.phi4_multimodal"].append("Phi4MultimodalImageProcessorFast") _import_structure["models.pixtral"].append("PixtralImageProcessorFast") _import_structure["models.qwen2_vl"].append("Qwen2VLImageProcessorFast") _import_structure["models.rt_detr"].append("RTDetrImageProcessorFast") @@ -2802,6 +2810,17 @@ else: "LlavaNextPreTrainedModel", ] ) + _import_structure["models.phi4_multimodal"].extend( + [ + "Phi4MultimodalForCausalLM", + "Phi4MultimodalPreTrainedModel", + "Phi4MultimodalAudioModel", + "Phi4MultimodalAudioPreTrainedModel", + "Phi4MultimodalModel", + "Phi4MultimodalVisionModel", + "Phi4MultimodalVisionPreTrainedModel", + ] + ) _import_structure["models.llava_next_video"].extend( [ "LlavaNextVideoForConditionalGeneration", @@ -5914,6 +5933,13 @@ if TYPE_CHECKING: ) from .models.phi import PhiConfig from .models.phi3 import Phi3Config + from .models.phi4_multimodal import ( + Phi4MultimodalAudioConfig, + Phi4MultimodalConfig, + Phi4MultimodalFeatureExtractor, + Phi4MultimodalProcessor, + Phi4MultimodalVisionConfig, + ) from .models.phimoe import PhimoeConfig from .models.phobert import PhobertTokenizer from .models.pix2struct import ( @@ -6587,6 +6613,7 @@ if TYPE_CHECKING: from .models.llava import LlavaImageProcessorFast from .models.llava_next import LlavaNextImageProcessorFast from .models.llava_onevision import LlavaOnevisionImageProcessorFast + from .models.phi4_multimodal import Phi4MultimodalImageProcessorFast from .models.pixtral import PixtralImageProcessorFast from .models.qwen2_vl import Qwen2VLImageProcessorFast from .models.rt_detr import RTDetrImageProcessorFast @@ -8153,6 +8180,15 @@ if TYPE_CHECKING: Phi3Model, Phi3PreTrainedModel, ) + from .models.phi4_multimodal import ( + Phi4MultimodalAudioModel, + Phi4MultimodalAudioPreTrainedModel, + Phi4MultimodalForCausalLM, + Phi4MultimodalModel, + Phi4MultimodalPreTrainedModel, + Phi4MultimodalVisionModel, + Phi4MultimodalVisionPreTrainedModel, + ) from .models.phimoe import ( PhimoeForCausalLM, PhimoeForSequenceClassification, diff --git a/src/transformers/audio_utils.py b/src/transformers/audio_utils.py index f54e5375c1..5795b5d9bd 100644 --- a/src/transformers/audio_utils.py +++ b/src/transformers/audio_utils.py @@ -17,11 +17,16 @@ and remove unnecessary dependencies. """ import warnings -from typing import Optional, Union +from typing import List, Optional, Tuple, Union import numpy as np +AudioInput = Union[ + np.ndarray, "torch.Tensor", List[np.ndarray], Tuple[np.ndarray], List["torch.Tensor"], Tuple["torch.Tensor"] # noqa: F821 +] + + def hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]: """ Convert frequency from hertz to mels. diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 06575ffceb..49ce48dd6c 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -212,6 +212,7 @@ from . import ( persimmon, phi, phi3, + phi4_multimodal, phimoe, phobert, pix2struct, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 712450e166..c7ef472882 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -235,6 +235,7 @@ CONFIG_MAPPING_NAMES = OrderedDict( ("persimmon", "PersimmonConfig"), ("phi", "PhiConfig"), ("phi3", "Phi3Config"), + ("phi4_multimodal", "Phi4MultimodalConfig"), ("phimoe", "PhimoeConfig"), ("pix2struct", "Pix2StructConfig"), ("pixtral", "PixtralVisionConfig"), @@ -587,6 +588,7 @@ MODEL_NAMES_MAPPING = OrderedDict( ("persimmon", "Persimmon"), ("phi", "Phi"), ("phi3", "Phi3"), + ("phi4_multimodal", "Phi4Multimodal"), ("phimoe", "Phimoe"), ("phobert", "PhoBERT"), ("pix2struct", "Pix2Struct"), diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index 134571014f..0b8b38bc34 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -78,6 +78,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict( ("nat", "ViTFeatureExtractor"), ("owlvit", "OwlViTFeatureExtractor"), ("perceiver", "PerceiverFeatureExtractor"), + ("phi4_multimodal", "Phi4MultimodalFeatureExtractor"), ("poolformer", "PoolFormerFeatureExtractor"), ("pop2piano", "Pop2PianoFeatureExtractor"), ("regnet", "ConvNextFeatureExtractor"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 7cd47bc060..77b9734189 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -124,6 +124,7 @@ else: ("owlvit", ("OwlViTImageProcessor",)), ("paligemma", ("SiglipImageProcessor", "SiglipImageProcessorFast")), ("perceiver", ("PerceiverImageProcessor",)), + ("phi4_multimodal", "Phi4MultimodalImageProcessorFast"), ("pix2struct", ("Pix2StructImageProcessor",)), ("pixtral", ("PixtralImageProcessor", "PixtralImageProcessorFast")), ("poolformer", ("PoolFormerImageProcessor",)), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 2dcd1fa3ef..05a4157414 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -218,6 +218,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ("persimmon", "PersimmonModel"), ("phi", "PhiModel"), ("phi3", "Phi3Model"), + ("phi4_multimodal", "Phi4MultimodalModel"), ("phimoe", "PhimoeModel"), ("pixtral", "PixtralVisionModel"), ("plbart", "PLBartModel"), @@ -566,6 +567,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ("persimmon", "PersimmonForCausalLM"), ("phi", "PhiForCausalLM"), ("phi3", "Phi3ForCausalLM"), + ("phi4_multimodal", "Phi4MultimodalForCausalLM"), ("phimoe", "PhimoeForCausalLM"), ("plbart", "PLBartForCausalLM"), ("prophetnet", "ProphetNetForCausalLM"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 5b699a4a44..48081b9df8 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -91,6 +91,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict( ("owlv2", "Owlv2Processor"), ("owlvit", "OwlViTProcessor"), ("paligemma", "PaliGemmaProcessor"), + ("phi4_multimodal", "Phi4MultimodalProcessor"), ("pix2struct", "Pix2StructProcessor"), ("pixtral", "PixtralProcessor"), ("pop2piano", "Pop2PianoProcessor"), diff --git a/src/transformers/models/phi4_multimodal/__init__.py b/src/transformers/models/phi4_multimodal/__init__.py new file mode 100644 index 0000000000..c4e2e599f5 --- /dev/null +++ b/src/transformers/models/phi4_multimodal/__init__.py @@ -0,0 +1,32 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_phi4_multimodal import * + from .feature_extraction_phi4_multimodal import * + from .image_processing_phi4_multimodal_fast import * + from .modeling_phi4_multimodal import * + from .processing_phi4_multimodal import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/phi4_multimodal/configuration_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/configuration_phi4_multimodal.py new file mode 100644 index 0000000000..3f776b0b71 --- /dev/null +++ b/src/transformers/models/phi4_multimodal/configuration_phi4_multimodal.py @@ -0,0 +1,482 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_phi4_multimodal.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +from ...configuration_utils import PretrainedConfig + + +class Phi4MultimodalVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Phi4MultimodalVisionModel`]. It is used to instantiate a + Phi4Multimodal vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the vision encoder of + [microsoft/Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 1152): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 4304): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 27): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input images. + image_size (`int`, *optional*, defaults to 448): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 14): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + crop_size (`int`, *optional*, defaults to 448): + Crop size for the input images. + image_token_id (`int`, *optional*, defaults to 200010): + The image token id. + feature_layer (`int`, *optional*, defaults to -2): + The index of the layer of the encoder from which to extract image features. + + Example: + + ```python + >>> from transformers import Phi4MultimodalVisionConfig + + >>> # Initializing a Phi4MultimodalVisionConfig with microsoft/Phi-4-multimodal-instruct style configuration + >>> configuration = Phi4MultimodalVisionConfig() + ```""" + + model_type = "phi4_multimodal_vision" + base_config_key = "vision_config" + + def __init__( + self, + hidden_size=1152, + intermediate_size=4304, + num_hidden_layers=27, + num_attention_heads=16, + num_channels=3, + image_size=448, + patch_size=14, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + crop_size: int = 448, + image_token_id: int = 200010, + feature_layer: int = -2, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.crop_size = crop_size + self.image_token_id = image_token_id + self.feature_layer = feature_layer + + +class Phi4MultimodalAudioConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Phi4MultimodalAudioModel`]. It is used to instantiate a + Phi4Multimodal audio encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the audio encoder of + [microsoft/Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 1024): + Dimensionality of the encoder layers. + intermediate_size (`int`, *optional*, defaults to 1536): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_blocks (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + activation (`str`, *optional*, defaults to `"swish"`): + The non-linear activation function in the MLPs. + chunk_size (`int`, *optional*, defaults to -1): + The chunk size to create the masks. + left_chunk (`int`, *optional*, defaults to 18): + The left chunk to create the masks. + dropout_rate (`float`, *optional*, defaults to 0.0): + The dropout ratio. + ext_pw_out_channel (`int`, *optional*, defaults to 1024): + Number of out channels in the point-wise conv modules. + depthwise_seperable_out_channel (`int`, *optional*, defaults to 1024): + Number of out channels in the depth-wise separable conv modules. + depthwise_multiplier (`int`, *optional*, defaults to 1): + Input size multiplier for the depth-wise separable conv modules. + kernel_size (`int`, *optional*, defaults to 3): + Kernel size for the depth-wise separable conv modules. + conv_activation (`str`, *optional*, defaults to `"swish"`): + The non-linear activation function in the conv modules. + input_size (`int`, *optional*, defaults to 80): + Input size for the audio model. + conv_glu_type (`str`, *optional*, defaults to `"swish"`): + The non-linear activation function in the point-wise conv modules. + time_reduction (`int`, *optional*, defaults to 8): + Time reduction (subsampling factor). + bias_max_distance (`int`, *optional*, defaults to 1000): + Max distance for the relative attention bias module. + bias_symmetric (`bool`, *optional*, defaults to `False`): + Whether the relative attention bias should be symmetric or not. + nemo_activation (`str`, *optional*, defaults to `"relu"`): + The non-linear activation function in the nemo conv modules. + nemo_conv_channels (`int`, *optional*, defaults to 1024): + Number of channels in the nemo conv modules. + downsample_rate (`int`, *optional*, defaults to 1): + Downsample rate for the audio feature extractor. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + audio_token_id (`int`, *optional*, defaults to 200011): + The audio token id. + feature_layer (`int`, *optional*, defaults to -2): + The index of the layer of the encoder from which to extract audio features. + + Example: + + ```python + >>> from transformers import Phi4MultimodalAudioConfig + + >>> # Initializing a Phi4MultimodalAudioConfig with microsoft/Phi-4-multimodal-instruct style configuration + >>> configuration = Phi4MultimodalAudioConfig() + ```""" + + model_type = "phi4_multimodal_audio" + + def __init__( + self, + hidden_size: int = 1024, + intermediate_size: int = 1536, + num_blocks: int = 24, + num_attention_heads: int = 16, + activation: str = "swish", + chunk_size: int = -1, + left_chunk: int = 18, + dropout_rate: float = 0.0, + ext_pw_out_channel: int = 1024, + depthwise_seperable_out_channel: int = 1024, + depthwise_multiplier: int = 1, + kernel_size: int = 3, + conv_activation: str = "swish", + input_size: int = 80, + conv_glu_type: str = "swish", + time_reduction: int = 8, + bias_max_distance: int = 1000, + bias_symmetric: bool = False, + nemo_activation: str = "relu", + nemo_conv_channels: int = 1024, + downsample_rate: int = 1, + initializer_range: float = 0.02, + audio_token_id: int = 200011, + feature_layer: int = -2, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.activation = activation + self.chunk_size = chunk_size + self.left_chunk = left_chunk + self.num_blocks = num_blocks + self.dropout_rate = dropout_rate + self.ext_pw_out_channel = ext_pw_out_channel + self.depthwise_seperable_out_channel = depthwise_seperable_out_channel + self.depthwise_multiplier = depthwise_multiplier + self.kernel_size = kernel_size + self.conv_activation = conv_activation + self.input_size = input_size + self.conv_glu_type = conv_glu_type + self.time_reduction = time_reduction + self.bias_max_distance = bias_max_distance + self.bias_symmetric = bias_symmetric + self.nemo_activation = nemo_activation + self.nemo_conv_channels = nemo_conv_channels + self.downsample_rate = downsample_rate + self.audio_token_id = audio_token_id + self.initializer_range = initializer_range + self.feature_layer = feature_layer + + if time_reduction % 2 != 0: + raise ValueError("`time_reduction` should be a multiple of 2!") + length = input_size + for _ in range(int(math.log(time_reduction, 2))): + length = math.floor((length - 1) / 2 + 1) + self.nemo_final_size = length + + +class Phi4MultimodalConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Phi4MultimodalModel`]. It is used to instantiate a + Phi4Multimodal model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the + [microsoft/Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 200064): + Vocabulary size of the Phi-3 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Phi3Model`]. + hidden_size (`int`, *optional*, defaults to 3072): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 8192): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + resid_pdrop (`float`, *optional*, defaults to 0.0): + Dropout probability for mlp outputs. + embd_pdrop (`int`, *optional*, defaults to 0.0): + The dropout ratio for the embeddings. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio after computing the attention scores. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 131072): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon value used for the RMSNorm. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. Whether to tie weight embeddings or not. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`dict`, *optional*): + The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must + contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be `longrope` and + the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size + divided by the number of attention heads divided by 2. + partial_rotary_factor (`float`, *optional*, defaults to `1.0`): + Percentage of the query and keys which will have rotary embedding. Must be between 0.0 and 1.0. + bos_token_id (`int`, *optional*, defaults to 199999): + The id of the "beginning-of-sequence" token. + eos_token_id (`int` or `list[int]`, *optional*, defaults to `[199999, 200020]`): + The id of the "end-of-sequence" token. + pad_token_id (`int`, *optional*, defaults to 199999): + The id of the padding token. + original_max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model was trained with. This is used to determine the size of the + original RoPE embeddings when using long scaling. + sliding_window (`int`, *optional*): + Sliding window attention window size. If `None`, no sliding window is applied. + vision_config (`Phi4MultimodalVisionConfig` or `dict`, *optional*): + The vision config for the underlying image embedding model. If not provided, will default to the configuration + used to instantiate a model similar in architecture as + [microsoft/Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct). + audio_config (`Phi4MultimodalAudioConfig` or `dict`, *optional*): + The audio config for the underlying audio embedding model. If not provided, will default to the configuration + used to instantiate a model similar in architecture as + [microsoft/Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct). + + Example: + + ```python + >>> from transformers import Phi4MultimodalModel, Phi4MultimodalConfig + + >>> # Initializing a Phi4Multimodal style configuration + >>> configuration = Phi4MultimodalConfig.from_pretrained("microsoft/Phi-4-multimodal-instruct") + + >>> # Initializing a model from the configuration + >>> model = Phi4MultimodalModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "phi4_multimodal" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.qkv_proj": "colwise_rep", # we need to replicate here due to the slicing of qkv + "layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the slicing of qkv + "layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation + "layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + sub_configs = {"audio_config": Phi4MultimodalAudioConfig, "vision_config": Phi4MultimodalVisionConfig} + + def __init__( + self, + vocab_size=200064, + hidden_size=3072, + intermediate_size=8192, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + resid_pdrop=0.0, + embd_pdrop=0.0, + attention_dropout=0.0, + hidden_act="silu", + max_position_embeddings=131072, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + partial_rotary_factor=1, + bos_token_id=199999, + eos_token_id=[199999, 200020], + pad_token_id=199999, + original_max_position_embeddings=4096, + sliding_window=None, + vision_config=None, + audio_config=None, + **kwargs, + ): + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attention_dropout = attention_dropout + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.original_max_position_embeddings = original_max_position_embeddings + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.partial_rotary_factor = partial_rotary_factor + self._rope_scaling_adjustment() + self._rope_scaling_validation() + self.sliding_window = sliding_window + + if isinstance(vision_config, dict): + vision_config = Phi4MultimodalVisionConfig(**vision_config) + elif vision_config is None: + Phi4MultimodalVisionConfig() + self.vision_config = vision_config + + if isinstance(audio_config, dict): + audio_config = Phi4MultimodalAudioConfig(**audio_config) + elif vision_config is None: + audio_config = Phi4MultimodalAudioConfig() + self.audio_config = audio_config + + def _rope_scaling_adjustment(self): + """ + Adjust the `type` of the `rope_scaling` configuration for backward compatibility. + """ + if self.rope_scaling is None: + return + + rope_scaling_type = self.rope_scaling.get("type", None) + + # For backward compatibility if previous version used "su" or "yarn" + if rope_scaling_type is not None and rope_scaling_type in ["su", "yarn"]: + self.rope_scaling["type"] = "longrope" + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3: + raise ValueError( + "`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, " + f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_short_factor = self.rope_scaling.get("short_factor", None) + rope_scaling_long_factor = self.rope_scaling.get("long_factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["longrope"]: + raise ValueError(f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}") + if not ( + isinstance(rope_scaling_short_factor, list) + and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor) + ): + raise ValueError( + f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}" + ) + rotary_ndims = int(self.hidden_size // self.num_attention_heads * self.partial_rotary_factor) + if not len(rope_scaling_short_factor) == rotary_ndims // 2: + raise ValueError( + f"`rope_scaling`'s short_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_short_factor)}" + ) + if not ( + isinstance(rope_scaling_long_factor, list) + and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor) + ): + raise ValueError( + f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}" + ) + if not len(rope_scaling_long_factor) == rotary_ndims // 2: + raise ValueError( + f"`rope_scaling`'s long_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_long_factor)}" + ) + + +__all__ = ["Phi4MultimodalVisionConfig", "Phi4MultimodalAudioConfig", "Phi4MultimodalConfig"] diff --git a/src/transformers/models/phi4_multimodal/convert_phi4_multimodal_weights_to_hf.py b/src/transformers/models/phi4_multimodal/convert_phi4_multimodal_weights_to_hf.py new file mode 100644 index 0000000000..c7cae2ab00 --- /dev/null +++ b/src/transformers/models/phi4_multimodal/convert_phi4_multimodal_weights_to_hf.py @@ -0,0 +1,229 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import json +import os +import re + +import torch +from peft import LoraConfig +from safetensors.torch import load_file, save_file + +from transformers import ( + Phi4MultimodalAudioConfig, + Phi4MultimodalConfig, + Phi4MultimodalForCausalLM, + Phi4MultimodalProcessor, + Phi4MultimodalVisionConfig, +) + + +# fmt: off +STATE_DICT_MAPPING = { + r"^model.embed_tokens_extend.audio_embed.encoder.encoders.(\d+).feed_forward_(in|out).net.0.linear": r"model.embed_tokens_extend.audio_embed.encoder.encoders.\1.feed_forward_\2.gate_up_proj", + r"^model.embed_tokens_extend.audio_embed.encoder.encoders.(\d+).feed_forward_(in|out).net.2": r"model.embed_tokens_extend.audio_embed.encoder.encoders.\1.feed_forward_\2.down_proj", + + r"^model.embed_tokens_extend.audio_embed.encoder.encoders.(\d+).self_attn.linear_(q|k|v)": r"model.embed_tokens_extend.audio_embed.encoder.encoders.\1.self_attn.\2_proj", + r"^model.embed_tokens_extend.audio_embed.encoder.encoders.(\d+).self_attn.linear_out": r"model.embed_tokens_extend.audio_embed.encoder.encoders.\1.self_attn.o_proj", + + r"^model.embed_tokens_extend.image_embed.img_projection.0": r"model.embed_tokens_extend.image_embed.img_projection_up", + r"^model.embed_tokens_extend.image_embed.img_projection.2": r"model.embed_tokens_extend.image_embed.img_projection_down", + + r"^model.embed_tokens_extend.image_embed.glb_GN": r"model.embed_tokens_extend.image_embed.global_img_feature_extensor", + r"^model.embed_tokens_extend.image_embed.sub_GN": r"model.embed_tokens_extend.image_embed.sub_img_feature_extensor", + + r"^model.embed_tokens_extend.audio_embed.audio_projection.speech.0": r"model.embed_tokens_extend.audio_embed.up_proj_for_speech", + r"^model.embed_tokens_extend.audio_embed.audio_projection.speech.2": r"model.embed_tokens_extend.audio_embed.down_proj_for_speech", + r"^model.embed_tokens_extend.audio_embed.audio_projection.vision.0": r"model.embed_tokens_extend.audio_embed.up_proj_for_vision_speech", + r"^model.embed_tokens_extend.audio_embed.audio_projection.vision.2": r"model.embed_tokens_extend.audio_embed.down_proj_for_vision_speech", +} +# fmt: on + + +def map_old_key_to_new(old_key): + """Map of a key of the original state dict to the equivalent key in HF format""" + for pattern, replacement in STATE_DICT_MAPPING.items(): + new_key, n_replace = re.subn(pattern, replacement, old_key) + # Early exit of the loop + if n_replace > 0: + return new_key + + # The state dict contains lora keys.... + if "lora" in old_key: + return None + # This extracts the original weight before adding the lora adapter + if "base_layer." in old_key: + return old_key.replace("base_layer.", "") + + # not part of the key mapping, we keep the original name + return old_key + + +def convert_state_dict(original_state_dict: dict): + """Convert a state dict file.""" + new_dict = {} + for old_key, tensor in original_state_dict.items(): + new_key = map_old_key_to_new(old_key) + if new_key is not None: + new_dict[new_key] = tensor + return new_dict + + +def convert_config(original_config: dict): + # Remove unused args + original_config.pop("_name_or_path", None) + original_config.pop("architectures", None) + original_config.pop("auto_map", None) + original_config.pop("vision_lora", None) + original_config.pop("speech_lora", None) + original_config.pop("transformers_version", None) + original_config.pop("_attn_implementation", None) + + embd_layer = original_config.pop("embd_layer") + audio_embd_layer = embd_layer["audio_embd_layer"] + vision_embd_layer = embd_layer["image_embd_layer"] + + # Keep only some of the subdict + keep_audio_embd_layer = ["downsample_rate"] + keep_vision_embd_layer = ["crop_size"] + audio_embd_layer = {k: v for k, v in audio_embd_layer.items() if k in keep_audio_embd_layer} + vision_embd_layer = {k: v for k, v in vision_embd_layer.items() if k in keep_vision_embd_layer} + + audio_config = original_config.pop("audio_processor")["config"] + # remove + audio_config.pop("activation_checkpointing", None) + audio_config.pop("cnn_layer_norm", None) + audio_config.pop("input_layer", None) + audio_config.pop("batch_norm", None) + audio_config.pop("encoder_embedding_config", None) + audio_config.pop("ext_pw_kernel_size", None) + audio_config.pop("bias_in_glu", None) + audio_config.pop("causal", None) + # rename + audio_config["hidden_size"] = audio_config.pop("attention_dim") + audio_config["num_attention_heads"] = audio_config.pop("attention_heads") + audio_config["intermediate_size"] = audio_config.pop("linear_units") + audio_config["nemo_conv_channels"] = audio_config.pop("nemo_conv_settings")["conv_channels"] + audio_config["bias_max_distance"] = audio_config.pop("relative_attention_bias_args")["t5_bias_max_distance"] + # add + audio_config = {**audio_config, **audio_embd_layer} + + # Create transformers config objects + audio_config = Phi4MultimodalAudioConfig(**audio_config) + vision_config = Phi4MultimodalVisionConfig(**vision_embd_layer) + + # Add 2nd eos to config + original_config["eos_token_id"] = [199999, 200020] + + new_config = Phi4MultimodalConfig(**original_config, vision_config=vision_config, audio_config=audio_config) + return new_config + + +def read_json(path): + with open(path, "r") as f: + return json.load(f) + + +def convert_and_write_model(input_dir: str, output_dir: str): + """Convert the model and save it (this implicitly save the config as well).""" + original_config = read_json(os.path.join(input_dir, "config.json")) + config = convert_config(original_config) + + full_state_dict = {} + shards = [file for file in os.listdir(input_dir) if file.endswith(".safetensors")] + for shard_file in shards: + original_state_dict = load_file(os.path.join(input_dir, shard_file)) + new_dict = convert_state_dict(original_state_dict) + full_state_dict.update(new_dict) + + # Load weights into model and resave them + with torch.device("meta"): + model = Phi4MultimodalForCausalLM(config) + missing, unexpected = model.load_state_dict(full_state_dict, strict=False, assign=True) + # The lm_head is missing because it's tied + if missing != ["lm_head.weight"]: + raise ValueError("Missing keys:\n{missing}") + if len(unexpected) > 0: + raise ValueError(f"Unexpected keys:\n{unexpected}") + + model.tie_weights() + model.save_pretrained(output_dir) + + +def convert_and_save_processor(input_dir: str, output_dir: str): + """Convert the processor.""" + processor = Phi4MultimodalProcessor.from_pretrained(input_dir) + del processor.image_processor.auto_map + del processor.audio_processor.auto_map + processor.chat_template = processor.tokenizer.chat_template + processor.tokenizer.extra_special_tokens = {"image_token": "<|endoftext10|>", "audio_token": "<|endoftext11|>"} + processor.save_pretrained(output_dir) + + +def extract_adapters_data(input_dir: str, output_dir: str): + """Extract adapters data from the state dict and save weights and configs.""" + speech_lora = {} + vision_lora = {} + shards = [file for file in os.listdir(input_dir) if file.endswith(".safetensors")] + for shard_file in shards: + original_state_dict = load_file(os.path.join(input_dir, shard_file)) + for k, v in original_state_dict.items(): + if "lora" in k: + if "speech" in k: + speech_lora[k.replace("speech.", "")] = v + elif "vision" in k: + vision_lora[k.replace("vision.", "")] = v + + # Create and save the lora configs + speech_lora_config = LoraConfig( + r=320, + lora_alpha=640, + target_modules=r"model.layers.\d+.((self_attn.(qkv|o)_proj)|(mlp.(gate_up|down)_proj))", + lora_dropout=0.01, + task_type="CAUSAL_LM", + ) + speech_lora_config.save_pretrained(os.path.join(output_dir, "speech-lora")) + vision_lora_config = LoraConfig( + r=256, + lora_alpha=512, + target_modules=r"model.layers.\d+.((self_attn.(qkv|o)_proj)|(mlp.(gate_up|down)_proj))", + lora_dropout=0.0, + task_type="CAUSAL_LM", + ) + vision_lora_config.save_pretrained(os.path.join(output_dir, "vision-lora")) + + save_file(speech_lora, os.path.join(output_dir, "speech-lora", "adapter_model.safetensors")) + save_file(vision_lora, os.path.join(output_dir, "vision-lora", "adapter_model.safetensors")) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "input_dir", + help="Location of the model folder containing the weights and configs.", + ) + parser.add_argument( + "output_dir", + help="Location to write HF model.", + ) + args = parser.parse_args() + + # Convert + convert_and_write_model(args.input_dir, args.output_dir) + convert_and_save_processor(args.input_dir, args.output_dir) + extract_adapters_data(args.input_dir, args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py new file mode 100644 index 0000000000..5d29af6c8b --- /dev/null +++ b/src/transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py @@ -0,0 +1,348 @@ +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Processor class for Phi4Multimodal +""" + +from typing import Optional, Union + +import numpy as np + +from ...audio_utils import AudioInput +from ...feature_extraction_sequence_utils import SequenceFeatureExtractor +from ...image_processing_utils import BatchFeature +from ...utils import TensorType, is_torch_available, logging + + +if is_torch_available(): + import torch + + +logger = logging.get_logger(__name__) + + +# TODO: @eustlb, remove this once #36603 is merged. +def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None): + """Create a Mel filter-bank the same as SpeechLib FbankFC. + + Args: + sample_rate (int): Sample rate in Hz. number > 0 [scalar] + n_fft (int): FFT size. int > 0 [scalar] + n_mel (int): Mel filter size. int > 0 [scalar] + fmin (float): lowest frequency (in Hz). If None use 0.0. + float >= 0 [scalar] + fmax: highest frequency (in Hz). If None use sample_rate / 2. + float >= 0 [scalar] + + Returns + out (numpy.ndarray): Mel transform matrix + [shape=(n_mels, 1 + n_fft/2)] + """ + + bank_width = int(n_fft // 2 + 1) + if fmax is None: + fmax = sample_rate / 2 + if fmin is None: + fmin = 0 + assert fmin >= 0, "fmin cannot be negtive" + assert fmin < fmax <= sample_rate / 2, "fmax must be between (fmin, samplerate / 2]" + + def mel(f): + return 1127.0 * np.log(1.0 + f / 700.0) + + def bin2mel(fft_bin): + return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0)) + + def f2bin(f): + return int((f * n_fft / sample_rate) + 0.5) + + # Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1] + klo = f2bin(fmin) + 1 + khi = f2bin(fmax) + + khi = max(khi, klo) + + # Spec 2: SpeechLib uses trianges in Mel space + mlo = mel(fmin) + mhi = mel(fmax) + m_centers = np.linspace(mlo, mhi, n_mels + 2) + ms = (mhi - mlo) / (n_mels + 1) + + matrix = np.zeros((n_mels, bank_width), dtype=np.float32) + for m in range(0, n_mels): + left = m_centers[m] + center = m_centers[m + 1] + right = m_centers[m + 2] + for fft_bin in range(klo, khi): + mbin = bin2mel(fft_bin) + if left < mbin < right: + matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms + + return matrix + + +class Phi4MultimodalFeatureExtractor(SequenceFeatureExtractor): + model_input_names = ["audio_input_features", "audio_embed_sizes", "audio_attention_mask"] + + def __init__( + self, + feature_size: int = 80, + sampling_rate: int = 16000, + hop_length: int = 160, + n_fft: int = 512, + win_length: int = 400, + preemphasis: float = 0.97, + padding_value: float = 0.0, + audio_compression_rate: int = 8, + audio_downsample_rate: int = 1, + audio_feat_stride: int = 1, + mel_min_frequency: float = 0, + mel_max_frequency: float = 7690, + **kwargs, + ): + super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) + + self.hop_length = hop_length + self.n_fft = n_fft + self.win_length = win_length + self.preemphasis = preemphasis + self.padding_value = padding_value + self.audio_compression_rate = audio_compression_rate + self.audio_downsample_rate = audio_downsample_rate + self.audio_feat_stride = audio_feat_stride + + # TODO: @eustlb, uncomment and remove speechlib_mel once #36603 is merged. + # self.mel_filters = mel_filter_bank( + # num_frequency_bins=self.n_fft // 2 + 1, + # num_mel_filters=self.feature_size, + # min_frequency=mel_min_frequency, + # max_frequency=mel_max_frequency, + # sampling_rate=self.sampling_rate, + # triangularize_in_mel_space=True, + # mel_scale="kaldi", + # ) + self.mel_filters = speechlib_mel( + self.sampling_rate, self.n_fft, self.feature_size, mel_min_frequency, mel_max_frequency + ).T + + def __call__( + self, + raw_speech: AudioInput, + sampling_rate: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + padding: Optional[str] = "longest", + max_length: Optional[int] = None, + truncation: bool = False, + return_tensors: Optional[Union[str, TensorType]] = None, + return_attention_mask: Optional[bool] = True, + device: Optional[str] = "cpu", + **kwargs, + ) -> BatchFeature: + """ + Main method to featurize and prepare for the model one or several audio sequence(s). Implementation uses PyTorch for + the STFT computation if available, otherwise a slower NumPy based one. + + Args: + raw_speech (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`): + The sequence or batch of sequences to be processed. Each sequence can be a numpy array or PyTorch tensor. + For batched inputs, sequences can be a list of numpy arrays or PyTorch tensors, or a single numpy array or + PyTorch tensor with first dimension being the batch size. + sampling_rate (`int`, *optional*): + The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass + `sampling_rate` at the forward call to prevent silent errors. + pad_to_multiple_of (`int`, *optional*, defaults to None): + If set will pad the sequence to a multiple of the provided value. + padding (`str`, *optional*, defaults to "longest"): + Padding strategy. Can be "longest" to pad to the longest sequence in the batch, or a specific length. + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length. + truncation (`bool`, *optional*, defaults to False): + Activates truncation to cut input sequences longer than *max_length* to *max_length*. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of numpy arrays. Acceptable values are: + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + - `'tf'`: Return TensorFlow `tf.constant` objects. + return_attention_mask (`bool`, *optional*, defaults to `True`): + Whether to return the extracted audio input features' attention mask. + device (`str`, *optional*, defaults to "cpu"): + Specifies the device for computation of the audio features. (e.g., "cpu", "cuda") + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + - **audio_input_features** -- Audio features extracted from the raw audio input, shape (batch_size, max_feature_length, feature_size). + - **audio_lengths** -- Length of each audio sample in the batch, shape (batch_size,). + - **audio_attention_mask** -- Attention mask for the audio input, shape (batch_size, max_feature_length). + If `return_tensors` is not specified, the fields will be PyTorch tensors if PyTorch is available, otherwise NumPy arrays. + """ + if sampling_rate is not None: + if sampling_rate != self.sampling_rate: + raise ValueError( + f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a" + f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input" + f" was sampled with {self.sampling_rate} and not {sampling_rate}." + ) + else: + logger.warning( + f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " + "Failing to do so can result in silent errors that might be hard to debug." + ) + + # Convert to torch tensor + if isinstance(raw_speech, np.ndarray): + raw_speech = torch.tensor(raw_speech) + elif isinstance(raw_speech, (list, tuple)) and isinstance(raw_speech[0], np.ndarray): + raw_speech = [torch.tensor(speech) for speech in raw_speech] + + is_batched_torch = isinstance(raw_speech, torch.Tensor) and len(raw_speech.shape) > 1 + if is_batched_torch and len(raw_speech.shape) > 2: + logger.warning( + f"Only mono-channel audio is supported for input to {self.__class__.__name__}. " + "We will take the mean of the channels to convert to mono." + ) + raw_speech = raw_speech.mean(-1) + + is_batched_sequence = isinstance(raw_speech, (list, tuple)) + if is_batched_sequence: + for speech in raw_speech: + if len(speech.shape) > 1: + logger.warning( + f"Only mono-channel audio is supported for input to {self.__class__.__name__}. " + "We will take the mean of the channels to convert to mono." + ) + speech = speech.mean(-1) + + if is_batched_torch or is_batched_sequence: + raw_speech = [speech[:, None].to(torch.float32) for speech in raw_speech] + else: + raw_speech = [raw_speech[:, None].to(torch.float32)] + + audio_lengths = [len(speech) for speech in raw_speech] + + # convert into correct format for padding + batched_speech = BatchFeature(data={"audio_input_features": raw_speech, "audio_lengths": audio_lengths}) + padded_inputs = self.pad( + batched_speech, + padding=padding, + max_length=max_length, + truncation=truncation, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors="pt", + ) + input_features = padded_inputs.audio_input_features.squeeze(-1) + audio_lengths = padded_inputs.audio_lengths + + input_features = self._torch_extract_fbank_features(input_features, audio_lengths, device) + + feature_lengths = (audio_lengths - self.win_length) // self.hop_length + 1 + feature_lengths = feature_lengths * self.audio_feat_stride + audio_embed_sizes = self._compute_audio_embed_size(feature_lengths) + + feature_attention_mask = ( + torch.arange(0, feature_lengths.max()) if is_torch_available() else np.arange(0, feature_lengths.max()) + ) + feature_attention_mask = ( + feature_attention_mask[None, :] < feature_lengths[:, None] if len(feature_lengths) > 1 else None + ) + + data = { + "audio_input_features": input_features, + "audio_embed_sizes": audio_embed_sizes, + } + if feature_attention_mask is not None and return_attention_mask: + data["audio_attention_mask"] = feature_attention_mask + + return BatchFeature(data=data, tensor_type=return_tensors) + + # TODO; @eustlb, move this to audio_utils in a general spectogram_batch function that handles torch and numpy + def _torch_extract_fbank_features( + self, waveform: "torch.FloatTensor", audio_lengths: "torch.Tensor", device: str = "cpu" + ) -> "torch.FloatTensor": + """ + Compute the log mel-scaled spectrogram of batched waveforms using PyTorch's FFT implementation. + + Args: + waveform (torch.FloatTensor` of shape `(batch_size, max_audio_length)`): + The batched waveforms. + audio_lengths (`torch.Tensor` of shape `(batch_size,)`): + The lengths of the waveforms along the max_audio_length dimension. + device (`str`, *optional*, defaults to "cpu"): + The device to run the computation on. (e.g., "cpu", "cuda") + + Returns: + `torch.FloatTensor` of shape `(batch_size, max_feature_length, feature_size)`: + The log mel-scaled spectrogram of the batched waveforms. + """ + fft_window = torch.hamming_window(self.win_length, periodic=False, device=device, dtype=torch.float64) + + # batched implementation + batch_size = waveform.shape[0] + frames = waveform.unfold(-1, self.win_length, self.hop_length) + + # --- + # the unbatched (and unpaded) original implementation skips last few audio values that can't be included in a frame + # we need to ensure that the corresponding frames for the padded input also mask these values + if batch_size > 1: + frames = frames.clone() + # concerned batch indices + to_mask_batch_idxs = torch.arange(batch_size)[audio_lengths != audio_lengths.max()] + if to_mask_batch_idxs.numel() > 0: + batch_idxs_down = (audio_lengths[to_mask_batch_idxs] - self.win_length) // self.hop_length + 1 + batch_idxs_up = audio_lengths[to_mask_batch_idxs] // self.hop_length + 1 + offset_idx = batch_idxs_down.min() + max_idx = batch_idxs_up.max() + + mask = torch.arange(max_idx - offset_idx, device=device).expand(to_mask_batch_idxs.shape[0], -1) + mask = ((batch_idxs_down - offset_idx).unsqueeze(1) <= mask) & ( + mask < (batch_idxs_up - offset_idx).unsqueeze(1) + ) + mask = mask.unsqueeze(-1).expand(-1, -1, self.win_length) + masked_frames = frames[to_mask_batch_idxs, offset_idx:max_idx].masked_fill_(mask, 0) + frames[to_mask_batch_idxs, offset_idx:max_idx] = masked_frames + # --- + + # apply pre-emphasis first order filter on fft windows + frames_prev = torch.roll(frames, 1, dims=-1) + frames_prev[:, :, 0] = frames_prev[:, :, 1] + frames = (frames - self.preemphasis * frames_prev) * 32768 + + # apply fft + S = torch.fft.rfft(fft_window * frames.view(-1, self.win_length), n=self.n_fft, dim=1) + S = S.view(frames.shape[0], -1, S.shape[-1]) + S = S.to(torch.complex64) + + spec = torch.abs(S) + spec_power = spec**2 + + # apply triangular mel filter bank + mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32) + log_spec = torch.clamp(spec_power @ mel_filters, min=1.0) + log_spec = torch.log(log_spec) + + return log_spec + + def _compute_audio_embed_size(self, audio_frames): + integer = audio_frames // self.audio_compression_rate + remainder = audio_frames % self.audio_compression_rate + result = integer + (remainder > 0).to(integer.dtype) + + integer = result // self.audio_downsample_rate + remainder = result % self.audio_downsample_rate + result = integer + (remainder > 0).to(integer.dtype) # qformer compression + + return result + + +__all__ = ["Phi4MultimodalFeatureExtractor"] diff --git a/src/transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py b/src/transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py new file mode 100644 index 0000000000..c81820ee32 --- /dev/null +++ b/src/transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py @@ -0,0 +1,263 @@ +# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Processor class for Phi4Multimodal +""" + +import math +from typing import List, Optional, Union + +import torch +from torchvision.transforms import functional as F + +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + BatchFeature, + DefaultFastImageProcessorKwargs, + Unpack, + convert_to_rgb, +) +from ...image_utils import ImageInput, make_list_of_images, valid_images +from ...utils import TensorType, logging + + +logger = logging.get_logger(__name__) + + +class Phi4MultimodalFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + image_size: Optional[int] + patch_size: Optional[int] + dynamic_hd: Optional[int] + + +class Phi4MultimodalImageProcessorFast(BaseImageProcessorFast): + r""" + Constructs a Phi4Multimodal image processor. + """ + + image_size = 448 + patch_size = 14 + dynamic_hd = 36 + image_mean = [0.5, 0.5, 0.5] + image_std = [0.5, 0.5, 0.5] + valid_init_kwargs = Phi4MultimodalFastImageProcessorKwargs + model_input_names = ["image_pixel_values", "image_sizes", "image_attention_mask"] + + def __init__(self, **kwargs: Unpack[Phi4MultimodalFastImageProcessorKwargs]): + super().__init__(**kwargs) + + def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height): + best_ratio_diff = float("inf") + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * self.image_size * self.image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + def dynamic_preprocess(self, image, max_num=36, min_num=1): + image_size = self.image_size + patch_size = self.patch_size + mask_size = image_size // patch_size + orig_width, orig_height = image.size + + w_crop_num = math.ceil(orig_width / float(image_size)) + h_crop_num = math.ceil(orig_height / float(image_size)) + if w_crop_num * h_crop_num > max_num: + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = { + (i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if i * j <= max_num and i * j >= min_num + } + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + target_aspect_ratio = self.find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + else: + target_width = image_size * w_crop_num + target_height = image_size * h_crop_num + target_aspect_ratio = (w_crop_num, h_crop_num) + + # Calculate the ratio + ratio_width = target_width / orig_width + ratio_height = target_height / orig_height + if ratio_width < ratio_height: + new_size = (target_width, int(orig_height * ratio_width)) + padding_width = 0 + padding_height = target_height - int(orig_height * ratio_width) + else: + new_size = (int(orig_width * ratio_height), target_height) + padding_width = target_width - int(orig_width * ratio_height) + padding_height = 0 + + attention_mask = torch.ones((int(mask_size * target_aspect_ratio[1]), int(mask_size * target_aspect_ratio[0]))) + if padding_width >= patch_size: + attention_mask[:, -math.floor(padding_width / patch_size) :] = 0 + if padding_height >= patch_size: + attention_mask[-math.floor(padding_height / patch_size) :, :] = 0 + + if min(new_size[1], target_height) < 10 or min(new_size[0], target_width) < 10: + raise ValueError(f"the aspect ratio is very extreme {new_size}") + + image = F.resize(image, [new_size[1], new_size[0]]) + resized_img = F.pad(image, [0, 0, padding_width, padding_height], fill=[255, 255, 255]) + + return resized_img, attention_mask + + def pad_to_max_num_crops(self, images, max_crops=5): + """ + images: B x 3 x H x W, B<=max_crops + """ + B, _, H, W = images.shape + if B < max_crops: + pad = torch.zeros(max_crops - B, 3, H, W, dtype=images.dtype, device=images.device) + images = torch.cat([images, pad], dim=0) + return images + + def pad_mask_to_max_num_crops(self, masks, max_crops=5): + B, H, W = masks.shape + if B < max_crops: + pad = torch.ones(max_crops - B, H, W, dtype=masks.dtype, device=masks.device) + masks = torch.cat([masks, pad], dim=0) + return masks + + def preprocess( + self, + images: ImageInput, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + ): + """ + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + """ + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + + images = make_list_of_images(images) + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + images = [convert_to_rgb(image) for image in images] + + image_size = self.image_size + patch_size = self.patch_size + mask_size = image_size // patch_size + imgs_and_masks = [self.dynamic_preprocess(image, max_num=self.dynamic_hd) for image in images] + images, image_attention_masks = [x[0] for x in imgs_and_masks], [x[1] for x in imgs_and_masks] + + images = [F.to_tensor(image) for image in images] + hd_images = [F.normalize(image, image_mean, image_std) for image in images] + global_image = [ + torch.nn.functional.interpolate( + image.unsqueeze(0).float(), + size=(image_size, image_size), + mode="bicubic", + ).to(image.dtype) + for image in hd_images + ] + + shapes = [[image.size(1), image.size(2)] for image in hd_images] + mask_shapes = [[mask.size(0), mask.size(1)] for mask in image_attention_masks] + global_attention_mask = [torch.ones((1, mask_size, mask_size)) for _ in hd_images] + + hd_images_reshape = [] + for im, (h, w) in zip(hd_images, shapes): + im = im.reshape(1, 3, h // image_size, image_size, w // image_size, image_size) + im = im.permute(0, 2, 4, 1, 3, 5) + im = im.reshape(-1, 3, image_size, image_size) + hd_images_reshape.append(im.contiguous()) + + attention_masks_reshape = [] + for mask, (h, w) in zip(image_attention_masks, mask_shapes): + mask = mask.reshape(h // mask_size, mask_size, w // mask_size, mask_size) + mask = mask.transpose(1, 2) + mask = mask.reshape(-1, mask_size, mask_size) + attention_masks_reshape.append(mask.contiguous()) + + downsample_attention_masks = [] + for mask, (h, w) in zip(attention_masks_reshape, mask_shapes): + mask = mask[:, 0::2, 0::2] + mask = mask.reshape( + h // mask_size, w // mask_size, mask_size // 2 + mask_size % 2, mask_size // 2 + mask_size % 2 + ) + mask = mask.transpose(1, 2) + mask = mask.reshape(mask.size(0) * mask.size(1), mask.size(2) * mask.size(3)) + downsample_attention_masks.append(mask) + + num_img_tokens = [ + 256 + 1 + int(mask.sum().item()) + int(mask[:, 0].sum().item()) + 16 for mask in downsample_attention_masks + ] + + hd_images_reshape = [ + torch.cat([_global_image] + [_im], dim=0) for _global_image, _im in zip(global_image, hd_images_reshape) + ] + hd_masks_reshape = [ + torch.cat([_global_mask] + [_mask], dim=0) + for _global_mask, _mask in zip(global_attention_mask, attention_masks_reshape) + ] + max_crops = max([img.size(0) for img in hd_images_reshape]) + image_transformed = [self.pad_to_max_num_crops(im, max_crops) for im in hd_images_reshape] + image_transformed = torch.stack(image_transformed, dim=0) + mask_transformed = [self.pad_mask_to_max_num_crops(mask, max_crops) for mask in hd_masks_reshape] + mask_transformed = torch.stack(mask_transformed, dim=0) + + returned_input_image_embeds = image_transformed + returned_image_sizes = torch.tensor(shapes, dtype=torch.long) + returned_image_attention_mask = mask_transformed + returned_num_img_tokens = num_img_tokens + + data = { + "image_pixel_values": returned_input_image_embeds, + "image_sizes": returned_image_sizes, + "image_attention_mask": returned_image_attention_mask, + "num_img_tokens": returned_num_img_tokens, + } + + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["Phi4MultimodalImageProcessorFast"] diff --git a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py new file mode 100644 index 0000000000..5d44fae131 --- /dev/null +++ b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py @@ -0,0 +1,2316 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_phi4_multimodal.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import warnings +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.init import _calculate_fan_in_and_fan_out + +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPast, + BaseModelOutputWithPooling, + CausalLMOutputWithPast, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, + torch_int, +) +from .configuration_phi4_multimodal import Phi4MultimodalAudioConfig, Phi4MultimodalConfig, Phi4MultimodalVisionConfig + + +logger = logging.get_logger(__name__) + + +class Phi4MultimodalVisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +def simple_eager_attention_forward( + module: nn.Module, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Phi4MultimodalVisionAttention(nn.Module): + def __init__(self, config: Phi4MultimodalVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.scaling = self.head_dim**-0.5 + self.is_causal = True + self.attention_dropout = config.attention_dropout + + self.k_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.v_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.q_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.out_proj = nn.Linear(config.hidden_size, config.hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + attention_interface: Callable = simple_eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1) + attn_output = self.out_proj(attn_output) + return attn_output, attn_weights + + +class Phi4MultimodalVisionEncoderLayer(nn.Module): + def __init__(self, config: Phi4MultimodalVisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = Phi4MultimodalVisionAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = Phi4MultimodalVisionMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + # Ignore copy + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class Phi4MultimodalVisionEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`Phi4MultimodalVisionEncoderLayer`]. + + Args: + config: Phi4MultimodalVisionConfig + """ + + def __init__(self, config: Phi4MultimodalVisionConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList( + [Phi4MultimodalVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + + # Ignore copy + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + + +def trunc_normal_tf_( + tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 +) -> torch.Tensor: + """Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \\leq \text{mean} \\leq b`. + + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the + bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 + and the result is subsequently scaled and shifted by the mean and std args. + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + """ + with torch.no_grad(): + _trunc_normal_(tensor, 0, 1.0, a, b) + tensor.mul_(std).add_(mean) + + +def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == "fan_in": + denom = fan_in + elif mode == "fan_out": + denom = fan_out + elif mode == "fan_avg": + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + elif distribution == "normal": + with torch.no_grad(): + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + with torch.no_grad(): + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") + + +def default_flax_embed_init(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="normal") + + +class Phi4MultimodalVisionPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = Phi4MultimodalVisionConfig + base_model_prefix = "phi4_vision" + supports_gradient_checkpointing = True + + _no_split_modules = ["Phi4MultimodalVisionEncoderLayer"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, Phi4MultimodalVisionEmbeddings): + width = ( + self.config.hidden_size + if isinstance(self.config, Phi4MultimodalVisionConfig) + else self.config.hidden_size + ) + nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) + elif isinstance(module, nn.Embedding): + default_flax_embed_init(module.weight) + elif isinstance(module, Phi4MultimodalVisionAttention): + nn.init.normal_(module.q_proj.weight) + nn.init.normal_(module.k_proj.weight) + nn.init.normal_(module.v_proj.weight) + nn.init.normal_(module.out_proj.weight) + nn.init.zeros_(module.q_proj.bias) + nn.init.zeros_(module.k_proj.bias) + nn.init.zeros_(module.v_proj.bias) + nn.init.zeros_(module.out_proj.bias) + elif isinstance(module, Phi4MultimodalVisionMLP): + nn.init.normal_(module.fc1.weight) + nn.init.normal_(module.fc2.weight) + nn.init.normal_(module.fc1.bias, std=1e-6) + nn.init.normal_(module.fc2.bias, std=1e-6) + elif isinstance(module, Phi4MultimodalVisionMultiheadAttentionPoolingHead): + nn.init.normal_(module.probe.data) + nn.init.normal_(module.attention.in_proj_weight.data) + nn.init.zeros_(module.attention.in_proj_bias.data) + elif isinstance(module, (nn.Linear, nn.Conv2d)): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +class Phi4MultimodalVisionEmbeddings(nn.Module): + def __init__(self, config: Phi4MultimodalVisionConfig): + super().__init__() + self.config = config + self.patch_size = config.patch_size + self.num_patches_per_side = config.image_size // self.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=config.hidden_size, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + self.position_embedding = nn.Embedding(self.num_patches_per_side**2, config.hidden_size) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing and no class embeddings. + + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] + num_positions = self.position_embedding.weight.shape[0] + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embedding(self.position_ids) + + patch_pos_embed = self.position_embedding.weight.unsqueeze(0) + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return patch_pos_embed + + def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor: + batch_size = pixel_values.size(0) + + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3) + max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size + boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side) + position_ids = torch.full((batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + + bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) + bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) + + pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + + position_ids = position_ids.to(self.position_embedding.weight.device) + + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + +class Phi4MultimodalVisionMultiheadAttentionPoolingHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__(self, config: Phi4MultimodalVisionConfig): + super().__init__() + + self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = Phi4MultimodalVisionMLP(config) + + def forward(self, hidden_state, attention_mask): + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + hidden_state = self.attention( + query=probe, key=hidden_state, value=hidden_state, key_padding_mask=~attention_mask + )[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +class Phi4MultimodalVisionModel(Phi4MultimodalVisionPreTrainedModel): + config_class = Phi4MultimodalVisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: Phi4MultimodalVisionConfig): + super().__init__(config) + self.config = config + + self.embeddings = Phi4MultimodalVisionEmbeddings(config) + self.encoder = Phi4MultimodalVisionEncoder(config) + self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.head = Phi4MultimodalVisionMultiheadAttentionPoolingHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.embeddings.patch_embedding + + def forward( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size = pixel_values.size(0) + if patch_attention_mask is None: + patch_attention_mask = torch.ones( + size=( + batch_size, + pixel_values.size(2) // self.config.patch_size, + pixel_values.size(3) // self.config.patch_size, + ), + dtype=torch.bool, + device=pixel_values.device, + ) + + hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask) + + patch_attention_mask = patch_attention_mask.view(batch_size, -1) + # The call to `_upad_input` in `_flash_attention_forward` is expensive + # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence), + # avoiding passing the attention_mask, which is equivalent to attending to the full sequence + if not torch.any(~patch_attention_mask): + attention_mask = None + else: + attention_mask = ( + _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) + if not self.config._attn_implementation == "flash_attention_2" + else patch_attention_mask + ) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooled_output = self.head( + hidden_state=last_hidden_state, + attention_mask=patch_attention_mask, + ) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class Phi4MultimodalImageEmbedding(nn.Module): + """Image embedding.""" + + def __init__(self, config: Phi4MultimodalConfig): + super().__init__() + self.config = config + self.layer_idx = config.vision_config.feature_layer + self.crop_size = config.vision_config.crop_size + self.image_dim_out = config.vision_config.hidden_size + + n_patches = config.vision_config.image_size // config.vision_config.patch_size + if n_patches % 2 != 0: + self.img_processor_padding = nn.ReflectionPad2d((0, 1, 0, 1)) + n_patches += 1 + self.num_img_tokens = (n_patches // 2) ** 2 + + self.drop = nn.Dropout(config.embd_pdrop) + self.img_processor = Phi4MultimodalVisionModel._from_config(config.vision_config) + self.image_token_compression = nn.AvgPool2d(kernel_size=2, stride=2) + self.img_projection_up = nn.Linear(self.image_dim_out, config.hidden_size) + self.img_projection_down = nn.Linear(config.hidden_size, config.hidden_size) + self.global_img_feature_extensor = nn.Parameter(torch.zeros([1, 1, self.image_dim_out])) + self.sub_img_feature_extensor = nn.Parameter(torch.zeros([1, 1, 1, self.image_dim_out])) + + def get_img_features(self, img_embeds: torch.FloatTensor, attention_mask=None) -> torch.FloatTensor: + img_processor_output = self.img_processor( + img_embeds, patch_attention_mask=attention_mask, output_hidden_states=True + ) + img_feature = img_processor_output.hidden_states[self.layer_idx] + + patch_feature = img_feature + # reshape to 2D tensor + width = int(math.sqrt(patch_feature.size(1))) + patch_feature = patch_feature.view(-1, width, width, patch_feature.size(-1)) + # convert to NCHW + patch_feature = patch_feature.permute(0, 3, 1, 2) + if getattr(self, "img_processor_padding", None) is not None: + patch_feature = self.img_processor_padding(patch_feature) + patch_feature = self.image_token_compression(patch_feature) + # convert to NHWC + patch_feature = patch_feature.permute(0, 2, 3, 1) + patch_feature = patch_feature.view(-1, patch_feature.size(1) * patch_feature.size(2), patch_feature.size(-1)) + return patch_feature + + def forward( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.Tensor, + image_pixel_values: torch.FloatTensor, + image_sizes: Optional[torch.Tensor] = None, + image_attention_mask: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + image_pixel_values = image_pixel_values.to(self.img_processor.embeddings.patch_embedding.weight.dtype) + + target_device = self.img_projection_up.bias.device + target_dtype = self.img_projection_up.bias.dtype + + batch_size = image_pixel_values.shape[0] + + img_features = self.get_img_features( + image_pixel_values.flatten(0, 1), + attention_mask=image_attention_mask.flatten(0, 1).to(dtype=bool, device=target_device), + ) + base_feat_size = int(np.sqrt(img_features.shape[1])) + img_features = img_features.view(batch_size, -1, base_feat_size**2, self.image_dim_out) + image_sizes = image_sizes.view(-1, 2) + + output_imgs = [] + for idx in range(batch_size): + height, width = image_sizes[idx] + height_ratio = height // self.crop_size + width_ratio = width // self.crop_size + area_ratio = height_ratio * width_ratio + + global_img = img_features[idx, :1] + global_img = global_img.reshape(1, base_feat_size, base_feat_size, self.image_dim_out).contiguous() + temporary_extensor = self.sub_img_feature_extensor.repeat(1, base_feat_size, 1, 1) + global_img = torch.cat([global_img, temporary_extensor], dim=2).reshape(1, -1, self.image_dim_out) + + sub_img = img_features[idx, 1:] + sub_img = sub_img[:area_ratio] + sub_img = ( + sub_img.reshape(height_ratio, width_ratio, base_feat_size, base_feat_size, self.image_dim_out) + .transpose(1, 2) + .reshape(1, height_ratio * base_feat_size, width_ratio * base_feat_size, self.image_dim_out) + .contiguous() + ) + + if image_attention_mask is not None: + reshaped_image_attention_mask = ( + image_attention_mask[idx, 1 : area_ratio + 1, 0::2, 0::2] + .reshape(height_ratio, width_ratio, base_feat_size, base_feat_size) + .transpose(1, 2) + .reshape(1, height_ratio * base_feat_size, width_ratio * base_feat_size) + ) + useful_height = int(reshaped_image_attention_mask[0, :, 0].sum().item()) + useful_width = int(reshaped_image_attention_mask[0, 0, :].sum().item()) + sub_img = sub_img[:, :useful_height, :useful_width] + temporary_extensor = self.sub_img_feature_extensor.repeat(1, useful_height, 1, 1) + else: + temporary_extensor = self.sub_img_feature_extensor.repeat(1, height_ratio * base_feat_size, 1, 1) + + sub_img = torch.cat([sub_img, temporary_extensor], dim=2).reshape(1, -1, self.image_dim_out) + + # Merge global and sub + output_imgs.append(torch.cat([sub_img, self.global_img_feature_extensor, global_img], dim=1)) + + img_set_tensor = [] + for output_img in output_imgs: + output_img = output_img.to(device=target_device, dtype=target_dtype) + img_feature_proj = self.img_projection_up(output_img) + img_feature_proj = nn.functional.gelu(img_feature_proj) + img_feature_proj = self.img_projection_down(img_feature_proj) + img_set_tensor.append(img_feature_proj) + + merged_img_set_tensor = torch.cat(img_set_tensor, dim=1).squeeze(0) + merged_img_set_tensor = merged_img_set_tensor.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device) + + with torch.no_grad(): + positions_tuple = torch.nonzero(input_ids == self.config.vision_config.image_token_id, as_tuple=True) + + # Temporarily disable autocast to avoid issue on bf16 tensors + # Ref: https://github.com/pytorch/pytorch/issues/132715 + with torch.autocast(device_type=inputs_embeds.device.type, enabled=False): + image_embeds = inputs_embeds.index_put( + indices=positions_tuple, values=merged_img_set_tensor, accumulate=False + ) + + image_embeds = self.drop(image_embeds) + + return image_embeds + + +########################################################## AUDIO ############################################# + + +class Phi4MultimodalAudioMLP(nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.layer_norm = nn.LayerNorm(config.hidden_size) + self.act_fn = ACT2FN[config.activation] + self.gate_up_proj = nn.Linear(config.hidden_size, config.intermediate_size * 2) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + hidden_states = self.layer_norm(hidden_states) + up_states = self.gate_up_proj(hidden_states) + up_states, gate = up_states.chunk(2, dim=-1) + up_states = up_states * self.act_fn(gate) + up_states = self.dropout(up_states) + hidden_states = self.down_proj(up_states) + out = self.dropout(hidden_states) + + return out + + +class Phi4MultimodalAudioAttention(nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.config = config + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.dropout_rate + self.is_causal = True + + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + **kwargs, + ): + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + attention_interface: Callable = simple_eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output + + +class Phi4MultimodalAudioDepthWiseSeperableConv1d(nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig, padding: int = 0): + super().__init__() + self.dw_conv = nn.Conv1d( + config.hidden_size, + config.hidden_size * config.depthwise_multiplier, + config.kernel_size, + 1, + padding=padding, + groups=config.hidden_size, + ) + self.pw_conv = nn.Conv1d( + config.hidden_size * config.depthwise_multiplier, config.depthwise_seperable_out_channel, 1, 1, 0 + ) + + def forward(self, hidden_states): + return self.pw_conv(self.dw_conv(hidden_states)) + + +class Phi4MultimodalAudioGluPointWiseConv(nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.config = config + self.output_dim = config.ext_pw_out_channel + + self.ext_pw_conv_1d = nn.Conv1d(config.hidden_size, config.ext_pw_out_channel * 2, kernel_size=1, stride=1) + self.glu_act = ACT2FN[config.conv_glu_type] + self.b1 = nn.Parameter(torch.zeros(1, config.ext_pw_out_channel, 1)) + self.b2 = nn.Parameter(torch.zeros(1, config.ext_pw_out_channel, 1)) + + def forward(self, hidden_states): + # we assume the input always has the #channel (#dim) in the last dimension of the + # tensor, so need to switch the dimension first for 1D-Conv case + hidden_states = hidden_states.permute([0, 2, 1]) + hidden_states = self.ext_pw_conv_1d(hidden_states) + out = hidden_states[:, 0 : self.output_dim, :] + self.b1 + out = out * self.glu_act(hidden_states[:, self.output_dim : self.output_dim * 2, :] + self.b2) + return out.permute([0, 2, 1]) + + +class Phi4MultimodalAudioConvModule(nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.config = config + self.kernel_size = config.kernel_size + + self.layer_norm = nn.LayerNorm(config.hidden_size) + self.glu = Phi4MultimodalAudioGluPointWiseConv(config) + self.dw_sep_conv_1d = Phi4MultimodalAudioDepthWiseSeperableConv1d(config, padding=config.kernel_size - 1) + self.act = ACT2FN[config.conv_activation] + self.ext_pw_conv_1d = nn.Conv1d(config.hidden_size, config.ext_pw_out_channel, kernel_size=1, stride=1) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = self.glu(self.layer_norm(hidden_states)) + hidden_states = self.dw_sep_conv_1d(hidden_states.permute([0, 2, 1])) + + if self.kernel_size > 1: + hidden_states = hidden_states[:, :, : -(self.kernel_size - 1)] + + hidden_states = self.act(hidden_states) + hidden_states = self.ext_pw_conv_1d(hidden_states) + out = self.dropout(hidden_states.permute([0, 2, 1])) + return out + + +class Phi4MultimodalAudioConformerEncoderLayer(nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + + self.feed_forward_in = Phi4MultimodalAudioMLP(config) + self.self_attn = Phi4MultimodalAudioAttention(config) + self.conv = Phi4MultimodalAudioConvModule(config) + self.feed_forward_out = Phi4MultimodalAudioMLP(config) + self.layer_norm_att = nn.LayerNorm(config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + ): + residual = hidden_states + 0.5 * self.feed_forward_in(hidden_states) + hidden_states = self.layer_norm_att(residual) + + hidden_states = residual + self.self_attn(hidden_states, attention_mask) + hidden_states = hidden_states + self.conv(hidden_states) + hidden_states = hidden_states + 0.5 * self.feed_forward_out(hidden_states) + + out = self.layer_norm(hidden_states) + + return out + + +class Phi4MultimodalAudioNemoConvSubsampling(torch.nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.subsampling_factor = config.time_reduction + self.sampling_num = int(math.log(self.subsampling_factor, 2)) + self.act_fn = ACT2FN[config.nemo_activation] + conv_channels = config.nemo_conv_channels + + layers = [ + nn.Conv2d(1, conv_channels, kernel_size=3, stride=2, padding=1), + self.act_fn, + ] + for _ in range(self.sampling_num - 1): + layers.extend( + [ + nn.Conv2d(conv_channels, conv_channels, kernel_size=3, stride=2, padding=1, groups=conv_channels), + nn.Conv2d(conv_channels, conv_channels, kernel_size=1, stride=1, padding=0, groups=1), + self.act_fn, + ] + ) + + # Aggregate the layers + self.conv = torch.nn.Sequential(*layers) + self.out = torch.nn.Linear(conv_channels * config.nemo_final_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor, mask: Optional[torch.Tensor]): + # Unsqueeze Channel Axis + hidden_states = hidden_states.unsqueeze(1) + hidden_states = self.conv(hidden_states) + + # Flatten Channel and Frequency Axes + b, _, t, _ = hidden_states.size() + hidden_states = self.out(hidden_states.transpose(1, 2).reshape(b, t, -1)) + + if mask is None: + return hidden_states, None + + max_audio_length = hidden_states.shape[1] + feature_lens = mask.sum(1) + padding_length = torch.ceil(feature_lens / self.subsampling_factor) + arange_ = torch.arange(0, max_audio_length, device=hidden_states.device) + pad_mask = arange_.expand(padding_length.size(0), -1) < padding_length.unsqueeze(1) + return hidden_states, pad_mask.unsqueeze(1) + + +class Phi4MultimodalAudioRelativeAttentionBias(nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + + self.max_distance = config.bias_max_distance + self.symmetric = config.bias_symmetric + self.num_buckets = self.max_distance + if not config.bias_symmetric: + self.num_buckets *= 2 + self.bias_values = nn.Embedding(self.num_buckets, config.num_attention_heads) + + def forward(self, x): + # instantiate bias compatible with shape of x + max_pos = x.size(1) + context_position = torch.arange(max_pos, device=x.device, dtype=torch.long)[:, None] + memory_position = torch.arange(max_pos, device=x.device, dtype=torch.long)[None, :] + relative_position = memory_position - context_position + # clipping to a maximum distance using ops that play well with ONNX export + relative_position = relative_position.masked_fill(relative_position < -self.max_distance, -self.max_distance) + relative_position = relative_position.masked_fill( + relative_position > self.max_distance - 1, self.max_distance - 1 + ) + + # mapping from relative position to index in the bias parameter + bias_idx = relative_position + bias_idx = bias_idx.abs() if self.symmetric else bias_idx + self.num_buckets // 2 + + att_bias = self.bias_values(bias_idx) + att_bias = att_bias.permute(2, 0, 1).unsqueeze(0) + + return att_bias + + +class Phi4MultimodalAudioMeanVarianceNormLayer(nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.register_buffer("global_mean", torch.zeros(config.input_size)) + self.register_buffer("global_invstd", torch.ones(config.input_size)) + + def forward(self, x): + return (x - self.global_mean) * self.global_invstd + + +class Phi4MultimodalAudioPreTrainedModel(PreTrainedModel): + config_class = Phi4MultimodalAudioConfig + supports_gradient_checkpointing = True + _no_split_modules = ["Phi4MultimodalAudioConformerEncoderLayer"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +def unfold_tensor(tensor, max_seq_len): + """ + For a given tensor with shape of (N, T, D), if sequence length T is longer than max_seq_len, + this function unfold it to a (NT', max_seq_len, D) where T' is T // max_seq_len. + Args: + tensor: N, T, D + """ + _, _, D = tensor.shape + tensor = tensor.transpose(-1, -2) + # N x D x 1 x T => N x (D x max_seq_len) x T' + tensor = F.unfold(tensor[..., None, :], kernel_size=(1, max_seq_len), stride=(1, max_seq_len)) + + new_bsz, _, slen = tensor.shape + tensor = tensor.view(new_bsz, -1, max_seq_len, slen) + tensor = tensor.permute(0, 3, 2, 1) + tensor = tensor.view(-1, max_seq_len, D).contiguous() + return tensor + + +def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0): + """ + The function is very important for Transformer Transducer Streaming mode + Args: + xs_len (int): sequence length + chunk_start_idx (list): first idx of each chunk, such as [0,18,36,48]. It also supports adaptive chunk size [0,10,15,45] + left_window (int): how many left chunks can be seen + right_window (int): how many right chunks can be seen. It is used for chunk overlap model. + Returns: + mask (torch.Tensor): a mask tensor for streaming model + """ + chunk_start_idx = torch.Tensor(chunk_start_idx).long() + start_pad = torch.nn.functional.pad( + chunk_start_idx, (1, 0) + ) # append 0 to the beginning, so it becomes [0, 0, 18, 36, 48] + end_pad = torch.nn.functional.pad( + chunk_start_idx, (0, 1), value=x_len + ) # append x_len to the end, so it becomes [0,18,36,48, x_len] + seq_range = torch.arange(0, x_len).unsqueeze(-1) + idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[:, 1] + seq_range_expand = torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1) + idx_left = idx - left_window + idx_left[idx_left < 0] = 0 + boundary_left = start_pad[idx_left] + mask_left = seq_range_expand >= boundary_left.unsqueeze(-1) + idx_right = idx + right_window + idx_right[idx_right > len(chunk_start_idx)] = len(chunk_start_idx) + boundary_right = end_pad[idx_right] + mask_right = seq_range_expand < boundary_right.unsqueeze(-1) + return mask_left & mask_right + + +class Phi4MultimodalAudioModel(Phi4MultimodalAudioPreTrainedModel): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__(config) + self.config = config + + self.encoder_embedding = Phi4MultimodalAudioMeanVarianceNormLayer(config) + self.embed = Phi4MultimodalAudioNemoConvSubsampling(config) + self.relative_attention_bias_layer = Phi4MultimodalAudioRelativeAttentionBias(config) + self.encoders = nn.ModuleList( + [Phi4MultimodalAudioConformerEncoderLayer(config) for _ in range(config.num_blocks)] + ) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk): + # Create mask matrix for streaming + # S stores start index. if chunksize is 18, s is [0,18,36,....] + chunk_start_idx = np.arange(0, seq_len, chunk_size) + # avoid randomness when run evaluation or decoding + if self.training and np.random.rand() > 0.5: + # Either first or last chunk is not complete. + # If only the last one is not complete, EOS is not effective + chunk_start_idx = seq_len - chunk_start_idx + chunk_start_idx = chunk_start_idx[::-1] + chunk_start_idx = chunk_start_idx[:-1] + chunk_start_idx = np.insert(chunk_start_idx, 0, 0) + + enc_streaming_mask = ( + adaptive_enc_mask(seq_len, chunk_start_idx, left_window=left_chunk) + .unsqueeze(0) + .expand([batch_size, -1, -1]) + ) + return enc_streaming_mask + + def forward_embeddings(self, hidden_states, masks): + """Forwarding the inputs through the top embedding layers""" + seq_len = math.ceil(hidden_states.shape[1] / self.config.time_reduction) + if seq_len <= 0: + raise ValueError( + f"The squence length after time reduction is invalid: {seq_len}. Your input feature is too short." + ) + + batch_size = hidden_states.shape[0] + + enc_streaming_mask = self._streaming_mask(seq_len, batch_size, self.config.chunk_size, self.config.left_chunk) + enc_streaming_mask = enc_streaming_mask.to(hidden_states.device) + + hidden_states, masks = self.embed(hidden_states, masks) + + streaming_mask = enc_streaming_mask + if streaming_mask is not None and masks is not None: + hs_mask = masks & streaming_mask + elif masks is not None: + hs_mask = masks + else: + hs_mask = streaming_mask + + return hidden_states, hs_mask, masks + + def calculate_hs_mask(self, hidden_states, device, mask): + max_audio_length = hidden_states.shape[1] + batch_size = hidden_states.shape[0] + enc_streaming_mask = self._streaming_mask( + max_audio_length, batch_size, self.config.chunk_size, self.config.left_chunk + ) + enc_streaming_mask = enc_streaming_mask.to(device) + if mask is None: + return enc_streaming_mask + + feature_lens = mask.sum(1) + padding_length = feature_lens + pad_mask = torch.arange(0, max_audio_length, device=device).expand( + padding_length.size(0), -1 + ) < padding_length.unsqueeze(1) + pad_mask = pad_mask.unsqueeze(1) + pad_mask = pad_mask & enc_streaming_mask + return pad_mask + + def forward(self, hidden_states: torch.Tensor, mask: Optional[torch.Tensor]): + hidden_states = self.encoder_embedding(hidden_states) + hidden_states, hs_mask, mask = self.forward_embeddings(hidden_states, mask) + + unfolded = False + bs, seq_len, _ = hidden_states.shape + max_seq_len = 500 # maxium position for absolute positional encoding + if seq_len > max_seq_len: + # audio sequence is longer than max_seq_len, unfold it into chunks of max_seq_len + unfolded = True + # the unfold op will drop residual frames, pad it to the multiple of max_seq_len + if seq_len % max_seq_len > 0: + chunk_pad_size = max_seq_len - (seq_len % max_seq_len) + else: + chunk_pad_size = 0 + if chunk_pad_size > 0: + hidden_states_pad = F.pad(hidden_states, (0, 0, 0, chunk_pad_size), "constant", 0) + hidden_states = hidden_states_pad.to(hidden_states.device) + + hidden_states = unfold_tensor(hidden_states, max_seq_len) + masks_unfold = None + if mask is not None: + # revise hs_mask here because the previous calculated hs_mask did not consider extra pad + subsampled_pad_mask = mask.squeeze(1) # [bz, subsampled_unmask_seq_len] + extra_padded_subsamlped_pad_mask = F.pad( + subsampled_pad_mask, (0, chunk_pad_size), "constant", False + ) # extra padding to the pad mask + extra_padded_subsamlped_pad_mask = extra_padded_subsamlped_pad_mask.unsqueeze(-1).float() + masks_unfold = unfold_tensor( + extra_padded_subsamlped_pad_mask, max_seq_len + ) # unfold the pad mask like we did to the input tensor + masks_unfold = masks_unfold.squeeze(-1).bool() # unfold op does not support bool tensor + hs_mask = self.calculate_hs_mask( + hidden_states, hidden_states.device, masks_unfold + ) # calculate hs_mask based on the unfolded pad mask + + relative_attention_bias = self.relative_attention_bias_layer(hidden_states) + attention_mask = hs_mask.unsqueeze(1) + relative_attention_bias + + for layer in self.encoders: + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + ) + else: + hidden_states = layer(hidden_states, attention_mask) + + if unfolded: + embed_dim = hidden_states.shape[-1] + hidden_states = hidden_states.reshape(bs, -1, embed_dim) + # if we ever padded before unfolding, we need to remove the padding + if chunk_pad_size > 0: + hidden_states = hidden_states[:, :-chunk_pad_size, :] + + return hidden_states + + +class Phi4MultimodalAudioEmbedding(nn.Module): + def __init__(self, config: Phi4MultimodalConfig): + super().__init__() + self.config = config + self.layer_idx = config.audio_config.feature_layer + + self.drop = nn.Dropout(config.embd_pdrop) + self.encoder = Phi4MultimodalAudioModel._from_config(config.audio_config) + self.up_proj_for_speech = nn.Linear( + config.audio_config.hidden_size * config.audio_config.downsample_rate, config.hidden_size + ) + self.down_proj_for_speech = nn.Linear(config.hidden_size, config.hidden_size) + self.up_proj_for_vision_speech = nn.Linear( + config.audio_config.hidden_size * config.audio_config.downsample_rate, config.hidden_size + ) + self.down_proj_for_vision_speech = nn.Linear(config.hidden_size, config.hidden_size) + + def forward( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.Tensor, + audio_input_features: torch.FloatTensor, + audio_embed_sizes=None, + audio_attention_mask=None, + audio_projection_mode="speech", + ) -> torch.FloatTensor: + with torch.no_grad(): + positions_tuple = torch.nonzero(input_ids == self.config.audio_config.audio_token_id, as_tuple=True) + + up_proj = self.up_proj_for_speech if audio_projection_mode == "speech" else self.up_proj_for_vision_speech + down_proj = ( + self.down_proj_for_speech if audio_projection_mode == "speech" else self.down_proj_for_vision_speech + ) + + target_device = up_proj.bias.device + target_dtype = up_proj.bias.dtype + + audio_input_features = audio_input_features.to(device=target_device, dtype=target_dtype) + + audio_encoder_hidden_states = self.encoder(audio_input_features, audio_attention_mask) + audio_encoder_hidden_states = up_proj(audio_encoder_hidden_states) + audio_encoder_hidden_states = nn.functional.gelu(audio_encoder_hidden_states) + audio_embeds = down_proj(audio_encoder_hidden_states) + + merged_audio_embeds = torch.cat( + [audio_embeds[i, : audio_embed_sizes[i], :] for i in range(len(audio_embed_sizes))], dim=0 + ) + merged_audio_embeds = merged_audio_embeds.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device) + # Temporarily disable autocast to avoid issue on bf16 tensors + # Ref: https://github.com/pytorch/pytorch/issues/132715 + with torch.autocast(device_type=inputs_embeds.device.type, enabled=False): + audio_embeds = inputs_embeds.index_put( + indices=positions_tuple, values=merged_audio_embeds, accumulate=False + ) + + audio_embeds = self.drop(audio_embeds) + + return audio_embeds + + +class Phi4MultimodalRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Phi4MultimodalRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Phi4MultimodalMLP(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + self.activation_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + up_states = self.gate_up_proj(hidden_states) + + gate, up_states = up_states.chunk(2, dim=-1) + up_states = up_states * self.activation_fn(gate) + + return self.down_proj(up_states) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1) + k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1) + return q_embed, k_embed + + +class Phi4MultimodalAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Phi4MultimodalConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.num_key_value_heads = config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + op_size = config.num_attention_heads * self.head_dim + 2 * (config.num_key_value_heads * self.head_dim) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.qkv_proj = nn.Linear(config.hidden_size, op_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + qkv = self.qkv_proj(hidden_states) + query_pos = self.config.num_attention_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=getattr(self.config, "sliding_window", None), + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Phi4MultimodalDecoderLayer(nn.Module): + def __init__(self, config: Phi4MultimodalConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Phi4MultimodalAttention(config=config, layer_idx=layer_idx) + self.mlp = Phi4MultimodalMLP(config) + self.input_layernorm = Phi4MultimodalRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Phi4MultimodalRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.config = config + self.resid_attn_dropout = nn.Dropout(config.resid_pdrop) + self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range + `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_value (`Cache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + self.resid_attn_dropout(hidden_states) # main diff with Llama + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.resid_mlp_dropout(hidden_states) # main diff with Llama + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class Phi4MultimodalFeatureEmbedding(nn.Module): + """Image-audio embedding.""" + + def __init__(self, config: Phi4MultimodalConfig) -> None: + super().__init__() + self.config = config + self.image_token_id = config.vision_config.image_token_id + self.audio_token_id = config.audio_config.audio_token_id + self.image_embed = Phi4MultimodalImageEmbedding(config) + self.audio_embed = Phi4MultimodalAudioEmbedding(config) + + def forward( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.Tensor, + image_pixel_values: Optional[torch.FloatTensor] = None, + audio_input_features: Optional[torch.FloatTensor] = None, + image_sizes=None, + image_attention_mask=None, + audio_embed_sizes=None, + audio_attention_mask=None, + ) -> torch.FloatTensor: + with torch.no_grad(): + image_position_mask = (input_ids == self.config.vision_config.image_token_id).unsqueeze(-1) + non_image_position_mask = ~image_position_mask + + image_embeds = None + audio_embeds = None + if image_pixel_values is not None and (input_ids == self.image_token_id).any(): + image_embeds = self.image_embed( + input_ids, + inputs_embeds, + image_pixel_values=image_pixel_values, + image_sizes=image_sizes, + image_attention_mask=image_attention_mask, + ) + if audio_input_features is not None and (input_ids == self.audio_token_id).any(): + audio_projection_mode = "vision" if image_pixel_values is not None else "speech" + audio_embeds = self.audio_embed( + input_ids, + inputs_embeds, + audio_input_features=audio_input_features, + audio_embed_sizes=audio_embed_sizes, + audio_attention_mask=audio_attention_mask, + audio_projection_mode=audio_projection_mode, + ) + + # merge image and audio + if image_embeds is not None and audio_embeds is not None: + inputs_embeds = image_embeds * image_position_mask + audio_embeds * non_image_position_mask + elif image_embeds is not None: + inputs_embeds = image_embeds + elif audio_embeds is not None: + inputs_embeds = audio_embeds + + return inputs_embeds + + +class Phi4MultimodalRotaryEmbedding(nn.Module): + def __init__(self, config: Phi4MultimodalConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + elif self.rope_type == "longrope": + self._longrope_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + def _longrope_frequency_update(self, position_ids, device): + """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise.""" + seq_len = torch.max(position_ids) + 1 + if hasattr(self.config, "original_max_position_embeddings"): + original_max_position_embeddings = self.config.original_max_position_embeddings + else: + original_max_position_embeddings = self.config.max_position_embeddings + if seq_len > original_max_position_embeddings: + if not hasattr(self, "long_inv_freq"): + self.long_inv_freq, _ = self.rope_init_fn( + self.config, device, seq_len=original_max_position_embeddings + 1 + ) + self.register_buffer("inv_freq", self.long_inv_freq, persistent=False) + else: + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + + +PHI4_MULTIMODAL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Phi4MultimodalConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Phi4Multimodal Model outputting raw hidden-states without any specific head on top.", + PHI4_MULTIMODAL_START_DOCSTRING, +) +class Phi4MultimodalPreTrainedModel(PreTrainedModel): + config_class = Phi4MultimodalConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Phi4MultimodalDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + _version = "0.0.5" + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding indices in `input_values`. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache`)`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + See our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + image_pixel_values (`torch.FloatTensor`, *optional*): + If the input contains images, these correspond to the pixel values after transformations (as returned by + the Processor) + image_sizes (`torch.LongTensor`, *optional*): + If the input contains images, these correspond to size of each image. + image_attention_mask (`torch.LongTensor`, *optional*): + Attention mask for the images. + audio_input_features (`torch.FloatTensor`, *optional*): + If the input contains audio samples, these correspond to the values after transformation (as returned by + the Processor). + audio_embed_sizes (`torch.Tensor`, *optional*): + Size of the audio inputs. + audio_attention_mask (`torch.Tensor, *optional*): + Attention mask for the audio inputs. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Phi4Multimodal Model outputting raw hidden-states without any specific head on top.", + PHI4_MULTIMODAL_START_DOCSTRING, +) +class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi4MultimodalMMDecoderLayer`] + Args: + config: Phi4MultimodalMMConfig + """ + + def __init__(self, config: Phi4MultimodalConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + + self.layers = nn.ModuleList( + [Phi4MultimodalDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Phi4MultimodalRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Phi4MultimodalRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + self.embed_dropout = nn.Dropout(config.embd_pdrop) + + self.embed_tokens_extend = Phi4MultimodalFeatureEmbedding(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + image_pixel_values: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.LongTensor] = None, + image_attention_mask=None, + audio_input_features: Optional[torch.FloatTensor] = None, + audio_embed_sizes=None, + audio_attention_mask=None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens_extend( + input_ids, + inputs_embeds, + image_pixel_values=image_pixel_values, + audio_input_features=audio_input_features, + image_sizes=image_sizes, + image_attention_mask=image_attention_mask, + audio_embed_sizes=audio_embed_sizes, + audio_attention_mask=audio_attention_mask, + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Phi4Multimodal. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: Phi4MultimodalConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + config (`Phi4MultimodalConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +class Phi4MultimodalForCausalLM(Phi4MultimodalPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = Phi4MultimodalModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=Phi4MultimodalConfig) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + image_pixel_values: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.LongTensor] = None, + image_attention_mask=None, + audio_input_features: Optional[torch.FloatTensor] = None, + audio_embed_sizes=None, + audio_attention_mask=None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + Returns: + + Example: + ```python + >>> from transformers import AutoTokenizer, Phi4MultimodalForCausalLM + >>> model = Phi4MultimodalForCausalLM.from_pretrained("TBA") + >>> tokenizer = AutoTokenizer.from_pretrained("TBA") + >>> prompt = "This is an example script ." + >>> inputs = tokenizer(prompt, return_tensors="pt") + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + 'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum' + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + image_pixel_values=image_pixel_values, + image_sizes=image_sizes, + image_attention_mask=image_attention_mask, + audio_input_features=audio_input_features, + audio_embed_sizes=audio_embed_sizes, + audio_attention_mask=audio_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + image_pixel_values=None, + image_sizes=None, + image_attention_mask=None, + audio_input_features=None, + audio_embed_sizes=None, + audio_attention_mask=None, + cache_position=None, + position_ids=None, + use_cache=True, + logits_to_keep=0, + **kwargs, + ): + # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the + # process + + # When the first time input length reached long and short factor switching point, enforce re-compute cache + # It will cause downside of slower at this single token position, however, better than current failure. + if ( + past_key_values + and self.config.rope_scaling + and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1 + ): + past_length = cache_position[0] + if past_length <= self.config.original_max_position_embeddings: + past_key_values = None + + model_inputs = super().prepare_inputs_for_generation( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + image_pixel_values=image_pixel_values, + image_sizes=image_sizes, + image_attention_mask=image_attention_mask, + audio_input_features=audio_input_features, + audio_embed_sizes=audio_embed_sizes, + audio_attention_mask=audio_attention_mask, + cache_position=cache_position, + position_ids=position_ids, + use_cache=use_cache, + logits_to_keep=logits_to_keep, + **kwargs, + ) + return model_inputs + + +__all__ = [ + "Phi4MultimodalAudioPreTrainedModel", + "Phi4MultimodalAudioModel", + "Phi4MultimodalVisionPreTrainedModel", + "Phi4MultimodalVisionModel", + "Phi4MultimodalPreTrainedModel", + "Phi4MultimodalModel", + "Phi4MultimodalForCausalLM", +] diff --git a/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py new file mode 100644 index 0000000000..06424941ec --- /dev/null +++ b/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py @@ -0,0 +1,1851 @@ +# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn + +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask + +from ...activations import ACT2FN +from ...cache_utils import DynamicCache +from ...configuration_utils import PretrainedConfig +from ...modeling_outputs import ( + BaseModelOutputWithPast, + BaseModelOutputWithPooling, + CausalLMOutputWithPast, +) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...utils import ( + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from ..phi3.configuration_phi3 import Phi3Config +from ..phi3.modeling_phi3 import Phi3DecoderLayer, Phi3ForCausalLM, Phi3Model, Phi3RMSNorm +from ..siglip.configuration_siglip import SiglipVisionConfig +from ..siglip.modeling_siglip import ( + SiglipEncoder, + SiglipEncoderLayer, + SiglipMLP, + SiglipMultiheadAttentionPoolingHead, + SiglipPreTrainedModel, + SiglipVisionEmbeddings, + default_flax_embed_init, + lecun_normal_, +) + + +logger = logging.get_logger(__name__) + + +class Phi4MultimodalVisionConfig(SiglipVisionConfig): + r""" + This is the configuration class to store the configuration of a [`Phi4MultimodalVisionModel`]. It is used to instantiate a + Phi4Multimodal vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the vision encoder of + [microsoft/Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 1152): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 4304): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 27): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input images. + image_size (`int`, *optional*, defaults to 448): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 14): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + crop_size (`int`, *optional*, defaults to 448): + Crop size for the input images. + image_token_id (`int`, *optional*, defaults to 200010): + The image token id. + feature_layer (`int`, *optional*, defaults to -2): + The index of the layer of the encoder from which to extract image features. + + Example: + + ```python + >>> from transformers import Phi4MultimodalVisionConfig + + >>> # Initializing a Phi4MultimodalVisionConfig with microsoft/Phi-4-multimodal-instruct style configuration + >>> configuration = Phi4MultimodalVisionConfig() + ```""" + + model_type = "phi4_multimodal_vision" + + def __init__( + self, + hidden_size=1152, + intermediate_size=4304, + num_hidden_layers=27, + num_attention_heads=16, + num_channels=3, + image_size=448, + patch_size=14, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + crop_size: int = 448, + image_token_id: int = 200010, + feature_layer: int = -2, + **kwargs, + ): + super().__init__( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + num_channels=num_channels, + image_size=image_size, + patch_size=patch_size, + hidden_act=hidden_act, + layer_norm_eps=layer_norm_eps, + attention_dropout=attention_dropout, + **kwargs, + ) + self.crop_size = crop_size + self.image_token_id = image_token_id + self.feature_layer = feature_layer + + +class Phi4MultimodalAudioConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Phi4MultimodalAudioModel`]. It is used to instantiate a + Phi4Multimodal audio encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the audio encoder of + [microsoft/Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 1024): + Dimensionality of the encoder layers. + intermediate_size (`int`, *optional*, defaults to 1536): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_blocks (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + activation (`str`, *optional*, defaults to `"swish"`): + The non-linear activation function in the MLPs. + chunk_size (`int`, *optional*, defaults to -1): + The chunk size to create the masks. + left_chunk (`int`, *optional*, defaults to 18): + The left chunk to create the masks. + dropout_rate (`float`, *optional*, defaults to 0.0): + The dropout ratio. + ext_pw_out_channel (`int`, *optional*, defaults to 1024): + Number of out channels in the point-wise conv modules. + depthwise_seperable_out_channel (`int`, *optional*, defaults to 1024): + Number of out channels in the depth-wise separable conv modules. + depthwise_multiplier (`int`, *optional*, defaults to 1): + Input size multiplier for the depth-wise separable conv modules. + kernel_size (`int`, *optional*, defaults to 3): + Kernel size for the depth-wise separable conv modules. + conv_activation (`str`, *optional*, defaults to `"swish"`): + The non-linear activation function in the conv modules. + input_size (`int`, *optional*, defaults to 80): + Input size for the audio model. + conv_glu_type (`str`, *optional*, defaults to `"swish"`): + The non-linear activation function in the point-wise conv modules. + time_reduction (`int`, *optional*, defaults to 8): + Time reduction (subsampling factor). + bias_max_distance (`int`, *optional*, defaults to 1000): + Max distance for the relative attention bias module. + bias_symmetric (`bool`, *optional*, defaults to `False`): + Whether the relative attention bias should be symmetric or not. + nemo_activation (`str`, *optional*, defaults to `"relu"`): + The non-linear activation function in the nemo conv modules. + nemo_conv_channels (`int`, *optional*, defaults to 1024): + Number of channels in the nemo conv modules. + downsample_rate (`int`, *optional*, defaults to 1): + Downsample rate for the audio feature extractor. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + audio_token_id (`int`, *optional*, defaults to 200011): + The audio token id. + feature_layer (`int`, *optional*, defaults to -2): + The index of the layer of the encoder from which to extract audio features. + + Example: + + ```python + >>> from transformers import Phi4MultimodalAudioConfig + + >>> # Initializing a Phi4MultimodalAudioConfig with microsoft/Phi-4-multimodal-instruct style configuration + >>> configuration = Phi4MultimodalAudioConfig() + ```""" + + model_type = "phi4_multimodal_audio" + + def __init__( + self, + hidden_size: int = 1024, + intermediate_size: int = 1536, + num_blocks: int = 24, + num_attention_heads: int = 16, + activation: str = "swish", + chunk_size: int = -1, + left_chunk: int = 18, + dropout_rate: float = 0.0, + ext_pw_out_channel: int = 1024, + depthwise_seperable_out_channel: int = 1024, + depthwise_multiplier: int = 1, + kernel_size: int = 3, + conv_activation: str = "swish", + input_size: int = 80, + conv_glu_type: str = "swish", + time_reduction: int = 8, + bias_max_distance: int = 1000, + bias_symmetric: bool = False, + nemo_activation: str = "relu", + nemo_conv_channels: int = 1024, + downsample_rate: int = 1, + initializer_range: float = 0.02, + audio_token_id: int = 200011, + feature_layer: int = -2, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.activation = activation + self.chunk_size = chunk_size + self.left_chunk = left_chunk + self.num_blocks = num_blocks + self.dropout_rate = dropout_rate + self.ext_pw_out_channel = ext_pw_out_channel + self.depthwise_seperable_out_channel = depthwise_seperable_out_channel + self.depthwise_multiplier = depthwise_multiplier + self.kernel_size = kernel_size + self.conv_activation = conv_activation + self.input_size = input_size + self.conv_glu_type = conv_glu_type + self.time_reduction = time_reduction + self.bias_max_distance = bias_max_distance + self.bias_symmetric = bias_symmetric + self.nemo_activation = nemo_activation + self.nemo_conv_channels = nemo_conv_channels + self.downsample_rate = downsample_rate + self.audio_token_id = audio_token_id + self.initializer_range = initializer_range + self.feature_layer = feature_layer + + if time_reduction % 2 != 0: + raise ValueError("`time_reduction` should be a multiple of 2!") + length = input_size + for _ in range(int(math.log(time_reduction, 2))): + length = math.floor((length - 1) / 2 + 1) + self.nemo_final_size = length + + +class Phi4MultimodalConfig(Phi3Config): + r""" + This is the configuration class to store the configuration of a [`Phi4MultimodalModel`]. It is used to instantiate a + Phi4Multimodal model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the + [microsoft/Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 200064): + Vocabulary size of the Phi-3 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Phi3Model`]. + hidden_size (`int`, *optional*, defaults to 3072): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 8192): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + resid_pdrop (`float`, *optional*, defaults to 0.0): + Dropout probability for mlp outputs. + embd_pdrop (`int`, *optional*, defaults to 0.0): + The dropout ratio for the embeddings. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio after computing the attention scores. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 131072): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon value used for the RMSNorm. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. Whether to tie weight embeddings or not. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`dict`, *optional*): + The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must + contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be `longrope` and + the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size + divided by the number of attention heads divided by 2. + partial_rotary_factor (`float`, *optional*, defaults to `1.0`): + Percentage of the query and keys which will have rotary embedding. Must be between 0.0 and 1.0. + bos_token_id (`int`, *optional*, defaults to 199999): + The id of the "beginning-of-sequence" token. + eos_token_id (`int` or `list[int]`, *optional*, defaults to `[199999, 200020]`): + The id of the "end-of-sequence" token. + pad_token_id (`int`, *optional*, defaults to 199999): + The id of the padding token. + original_max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model was trained with. This is used to determine the size of the + original RoPE embeddings when using long scaling. + sliding_window (`int`, *optional*): + Sliding window attention window size. If `None`, no sliding window is applied. + vision_config (`Phi4MultimodalVisionConfig` or `dict`, *optional*): + The vision config for the underlying image embedding model. If not provided, will default to the configuration + used to instantiate a model similar in architecture as + [microsoft/Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct). + audio_config (`Phi4MultimodalAudioConfig` or `dict`, *optional*): + The audio config for the underlying audio embedding model. If not provided, will default to the configuration + used to instantiate a model similar in architecture as + [microsoft/Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct). + + Example: + + ```python + >>> from transformers import Phi4MultimodalModel, Phi4MultimodalConfig + + >>> # Initializing a Phi4Multimodal style configuration + >>> configuration = Phi4MultimodalConfig.from_pretrained("microsoft/Phi-4-multimodal-instruct") + + >>> # Initializing a model from the configuration + >>> model = Phi4MultimodalModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + sub_configs = {"audio_config": Phi4MultimodalAudioConfig, "vision_config": Phi4MultimodalVisionConfig} + + def __init__( + self, + vocab_size=200064, + hidden_size=3072, + intermediate_size=8192, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + resid_pdrop=0.0, + embd_pdrop=0.0, + attention_dropout=0.0, + hidden_act="silu", + max_position_embeddings=131072, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + partial_rotary_factor=1, + bos_token_id=199999, + eos_token_id=[199999, 200020], + pad_token_id=199999, + original_max_position_embeddings=4096, + sliding_window=None, + vision_config=None, + audio_config=None, + **kwargs, + ): + super().__init__( + vocab_size=vocab_size, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + resid_pdrop=resid_pdrop, + embd_pdrop=embd_pdrop, + attention_dropout=attention_dropout, + hidden_act=hidden_act, + max_position_embeddings=max_position_embeddings, + initializer_range=initializer_range, + rms_norm_eps=rms_norm_eps, + use_cache=use_cache, + tie_word_embeddings=tie_word_embeddings, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + partial_rotary_factor=partial_rotary_factor, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + original_max_position_embeddings=original_max_position_embeddings, + sliding_window=sliding_window, + **kwargs, + ) + + if isinstance(vision_config, dict): + vision_config = Phi4MultimodalVisionConfig(**vision_config) + elif vision_config is None: + Phi4MultimodalVisionConfig() + self.vision_config = vision_config + + if isinstance(audio_config, dict): + audio_config = Phi4MultimodalAudioConfig(**audio_config) + elif vision_config is None: + audio_config = Phi4MultimodalAudioConfig() + self.audio_config = audio_config + + +class Phi4MultimodalVisionMLP(SiglipMLP): + pass + + +def simple_eager_attention_forward( + module: nn.Module, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Phi4MultimodalVisionAttention(nn.Module): + def __init__(self, config: Phi4MultimodalVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.scaling = self.head_dim**-0.5 + self.is_causal = True + self.attention_dropout = config.attention_dropout + + self.k_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.v_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.q_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.out_proj = nn.Linear(config.hidden_size, config.hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + attention_interface: Callable = simple_eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1) + attn_output = self.out_proj(attn_output) + return attn_output, attn_weights + + +class Phi4MultimodalVisionEncoderLayer(SiglipEncoderLayer): + def __init__(self, config: Phi4MultimodalVisionConfig): + super().__init__(config) + self.self_attn = Phi4MultimodalVisionAttention(config) + self.mlp = Phi4MultimodalVisionMLP(config) + + +class Phi4MultimodalVisionEncoder(SiglipEncoder): + def __init__(self, config: Phi4MultimodalVisionConfig): + super().__init__() + self.layers = nn.ModuleList( + [Phi4MultimodalVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + + +class Phi4MultimodalVisionPreTrainedModel(SiglipPreTrainedModel): + config_class = Phi4MultimodalVisionConfig + base_model_prefix = "phi4_vision" + supports_gradient_checkpointing = True + + _no_split_modules = ["Phi4MultimodalVisionEncoderLayer"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, Phi4MultimodalVisionEmbeddings): + width = ( + self.config.hidden_size + if isinstance(self.config, Phi4MultimodalVisionConfig) + else self.config.hidden_size + ) + nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) + elif isinstance(module, nn.Embedding): + default_flax_embed_init(module.weight) + elif isinstance(module, Phi4MultimodalVisionAttention): + nn.init.normal_(module.q_proj.weight) + nn.init.normal_(module.k_proj.weight) + nn.init.normal_(module.v_proj.weight) + nn.init.normal_(module.out_proj.weight) + nn.init.zeros_(module.q_proj.bias) + nn.init.zeros_(module.k_proj.bias) + nn.init.zeros_(module.v_proj.bias) + nn.init.zeros_(module.out_proj.bias) + elif isinstance(module, Phi4MultimodalVisionMLP): + nn.init.normal_(module.fc1.weight) + nn.init.normal_(module.fc2.weight) + nn.init.normal_(module.fc1.bias, std=1e-6) + nn.init.normal_(module.fc2.bias, std=1e-6) + elif isinstance(module, Phi4MultimodalVisionMultiheadAttentionPoolingHead): + nn.init.normal_(module.probe.data) + nn.init.normal_(module.attention.in_proj_weight.data) + nn.init.zeros_(module.attention.in_proj_bias.data) + elif isinstance(module, (nn.Linear, nn.Conv2d)): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +class Phi4MultimodalVisionEmbeddings(SiglipVisionEmbeddings, nn.Module): + def __init__(self, config: Phi4MultimodalVisionConfig): + nn.Module.__init__() + self.config = config + self.patch_size = config.patch_size + self.num_patches_per_side = config.image_size // self.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=config.hidden_size, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + self.position_embedding = nn.Embedding(self.num_patches_per_side**2, config.hidden_size) + + def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor: + batch_size = pixel_values.size(0) + + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3) + max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size + boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side) + position_ids = torch.full((batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + + bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) + bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) + + pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + + position_ids = position_ids.to(self.position_embedding.weight.device) + + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + +class Phi4MultimodalVisionMultiheadAttentionPoolingHead(SiglipMultiheadAttentionPoolingHead): + def __init__(self, config: Phi4MultimodalVisionConfig): + super().__init__(config) + self.mlp = Phi4MultimodalVisionMLP(config) + + def forward(self, hidden_state, attention_mask): + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + hidden_state = self.attention( + query=probe, key=hidden_state, value=hidden_state, key_padding_mask=~attention_mask + )[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +class Phi4MultimodalVisionModel(Phi4MultimodalVisionPreTrainedModel): + config_class = Phi4MultimodalVisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: Phi4MultimodalVisionConfig): + super().__init__(config) + self.config = config + + self.embeddings = Phi4MultimodalVisionEmbeddings(config) + self.encoder = Phi4MultimodalVisionEncoder(config) + self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.head = Phi4MultimodalVisionMultiheadAttentionPoolingHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.embeddings.patch_embedding + + def forward( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size = pixel_values.size(0) + if patch_attention_mask is None: + patch_attention_mask = torch.ones( + size=( + batch_size, + pixel_values.size(2) // self.config.patch_size, + pixel_values.size(3) // self.config.patch_size, + ), + dtype=torch.bool, + device=pixel_values.device, + ) + + hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask) + + patch_attention_mask = patch_attention_mask.view(batch_size, -1) + # The call to `_upad_input` in `_flash_attention_forward` is expensive + # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence), + # avoiding passing the attention_mask, which is equivalent to attending to the full sequence + if not torch.any(~patch_attention_mask): + attention_mask = None + else: + attention_mask = ( + _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) + if not self.config._attn_implementation == "flash_attention_2" + else patch_attention_mask + ) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooled_output = self.head( + hidden_state=last_hidden_state, + attention_mask=patch_attention_mask, + ) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class Phi4MultimodalImageEmbedding(nn.Module): + """Image embedding.""" + + def __init__(self, config: Phi4MultimodalConfig): + super().__init__() + self.config = config + self.layer_idx = config.vision_config.feature_layer + self.crop_size = config.vision_config.crop_size + self.image_dim_out = config.vision_config.hidden_size + + n_patches = config.vision_config.image_size // config.vision_config.patch_size + if n_patches % 2 != 0: + self.img_processor_padding = nn.ReflectionPad2d((0, 1, 0, 1)) + n_patches += 1 + self.num_img_tokens = (n_patches // 2) ** 2 + + self.drop = nn.Dropout(config.embd_pdrop) + self.img_processor = Phi4MultimodalVisionModel._from_config(config.vision_config) + self.image_token_compression = nn.AvgPool2d(kernel_size=2, stride=2) + self.img_projection_up = nn.Linear(self.image_dim_out, config.hidden_size) + self.img_projection_down = nn.Linear(config.hidden_size, config.hidden_size) + self.global_img_feature_extensor = nn.Parameter(torch.zeros([1, 1, self.image_dim_out])) + self.sub_img_feature_extensor = nn.Parameter(torch.zeros([1, 1, 1, self.image_dim_out])) + + def get_img_features(self, img_embeds: torch.FloatTensor, attention_mask=None) -> torch.FloatTensor: + img_processor_output = self.img_processor( + img_embeds, patch_attention_mask=attention_mask, output_hidden_states=True + ) + img_feature = img_processor_output.hidden_states[self.layer_idx] + + patch_feature = img_feature + # reshape to 2D tensor + width = int(math.sqrt(patch_feature.size(1))) + patch_feature = patch_feature.view(-1, width, width, patch_feature.size(-1)) + # convert to NCHW + patch_feature = patch_feature.permute(0, 3, 1, 2) + if getattr(self, "img_processor_padding", None) is not None: + patch_feature = self.img_processor_padding(patch_feature) + patch_feature = self.image_token_compression(patch_feature) + # convert to NHWC + patch_feature = patch_feature.permute(0, 2, 3, 1) + patch_feature = patch_feature.view(-1, patch_feature.size(1) * patch_feature.size(2), patch_feature.size(-1)) + return patch_feature + + def forward( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.Tensor, + image_pixel_values: torch.FloatTensor, + image_sizes: Optional[torch.Tensor] = None, + image_attention_mask: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + image_pixel_values = image_pixel_values.to(self.img_processor.embeddings.patch_embedding.weight.dtype) + + target_device = self.img_projection_up.bias.device + target_dtype = self.img_projection_up.bias.dtype + + batch_size = image_pixel_values.shape[0] + + img_features = self.get_img_features( + image_pixel_values.flatten(0, 1), + attention_mask=image_attention_mask.flatten(0, 1).to(dtype=bool, device=target_device), + ) + base_feat_size = int(np.sqrt(img_features.shape[1])) + img_features = img_features.view(batch_size, -1, base_feat_size**2, self.image_dim_out) + image_sizes = image_sizes.view(-1, 2) + + output_imgs = [] + for idx in range(batch_size): + height, width = image_sizes[idx] + height_ratio = height // self.crop_size + width_ratio = width // self.crop_size + area_ratio = height_ratio * width_ratio + + global_img = img_features[idx, :1] + global_img = global_img.reshape(1, base_feat_size, base_feat_size, self.image_dim_out).contiguous() + temporary_extensor = self.sub_img_feature_extensor.repeat(1, base_feat_size, 1, 1) + global_img = torch.cat([global_img, temporary_extensor], dim=2).reshape(1, -1, self.image_dim_out) + + sub_img = img_features[idx, 1:] + sub_img = sub_img[:area_ratio] + sub_img = ( + sub_img.reshape(height_ratio, width_ratio, base_feat_size, base_feat_size, self.image_dim_out) + .transpose(1, 2) + .reshape(1, height_ratio * base_feat_size, width_ratio * base_feat_size, self.image_dim_out) + .contiguous() + ) + + if image_attention_mask is not None: + reshaped_image_attention_mask = ( + image_attention_mask[idx, 1 : area_ratio + 1, 0::2, 0::2] + .reshape(height_ratio, width_ratio, base_feat_size, base_feat_size) + .transpose(1, 2) + .reshape(1, height_ratio * base_feat_size, width_ratio * base_feat_size) + ) + useful_height = int(reshaped_image_attention_mask[0, :, 0].sum().item()) + useful_width = int(reshaped_image_attention_mask[0, 0, :].sum().item()) + sub_img = sub_img[:, :useful_height, :useful_width] + temporary_extensor = self.sub_img_feature_extensor.repeat(1, useful_height, 1, 1) + else: + temporary_extensor = self.sub_img_feature_extensor.repeat(1, height_ratio * base_feat_size, 1, 1) + + sub_img = torch.cat([sub_img, temporary_extensor], dim=2).reshape(1, -1, self.image_dim_out) + + # Merge global and sub + output_imgs.append(torch.cat([sub_img, self.global_img_feature_extensor, global_img], dim=1)) + + img_set_tensor = [] + for output_img in output_imgs: + output_img = output_img.to(device=target_device, dtype=target_dtype) + img_feature_proj = self.img_projection_up(output_img) + img_feature_proj = nn.functional.gelu(img_feature_proj) + img_feature_proj = self.img_projection_down(img_feature_proj) + img_set_tensor.append(img_feature_proj) + + merged_img_set_tensor = torch.cat(img_set_tensor, dim=1).squeeze(0) + merged_img_set_tensor = merged_img_set_tensor.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device) + + with torch.no_grad(): + positions_tuple = torch.nonzero(input_ids == self.config.vision_config.image_token_id, as_tuple=True) + + # Temporarily disable autocast to avoid issue on bf16 tensors + # Ref: https://github.com/pytorch/pytorch/issues/132715 + with torch.autocast(device_type=inputs_embeds.device.type, enabled=False): + image_embeds = inputs_embeds.index_put( + indices=positions_tuple, values=merged_img_set_tensor, accumulate=False + ) + + image_embeds = self.drop(image_embeds) + + return image_embeds + + +########################################################## AUDIO ############################################# + + +class Phi4MultimodalAudioMLP(nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.layer_norm = nn.LayerNorm(config.hidden_size) + self.act_fn = ACT2FN[config.activation] + self.gate_up_proj = nn.Linear(config.hidden_size, config.intermediate_size * 2) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + hidden_states = self.layer_norm(hidden_states) + up_states = self.gate_up_proj(hidden_states) + up_states, gate = up_states.chunk(2, dim=-1) + up_states = up_states * self.act_fn(gate) + up_states = self.dropout(up_states) + hidden_states = self.down_proj(up_states) + out = self.dropout(hidden_states) + + return out + + +class Phi4MultimodalAudioAttention(nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.config = config + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.dropout_rate + self.is_causal = True + + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + **kwargs, + ): + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + attention_interface: Callable = simple_eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output + + +class Phi4MultimodalAudioDepthWiseSeperableConv1d(nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig, padding: int = 0): + super().__init__() + self.dw_conv = nn.Conv1d( + config.hidden_size, + config.hidden_size * config.depthwise_multiplier, + config.kernel_size, + 1, + padding=padding, + groups=config.hidden_size, + ) + self.pw_conv = nn.Conv1d( + config.hidden_size * config.depthwise_multiplier, config.depthwise_seperable_out_channel, 1, 1, 0 + ) + + def forward(self, hidden_states): + return self.pw_conv(self.dw_conv(hidden_states)) + + +class Phi4MultimodalAudioGluPointWiseConv(nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.config = config + self.output_dim = config.ext_pw_out_channel + + self.ext_pw_conv_1d = nn.Conv1d(config.hidden_size, config.ext_pw_out_channel * 2, kernel_size=1, stride=1) + self.glu_act = ACT2FN[config.conv_glu_type] + self.b1 = nn.Parameter(torch.zeros(1, config.ext_pw_out_channel, 1)) + self.b2 = nn.Parameter(torch.zeros(1, config.ext_pw_out_channel, 1)) + + def forward(self, hidden_states): + # we assume the input always has the #channel (#dim) in the last dimension of the + # tensor, so need to switch the dimension first for 1D-Conv case + hidden_states = hidden_states.permute([0, 2, 1]) + hidden_states = self.ext_pw_conv_1d(hidden_states) + out = hidden_states[:, 0 : self.output_dim, :] + self.b1 + out = out * self.glu_act(hidden_states[:, self.output_dim : self.output_dim * 2, :] + self.b2) + return out.permute([0, 2, 1]) + + +class Phi4MultimodalAudioConvModule(nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.config = config + self.kernel_size = config.kernel_size + + self.layer_norm = nn.LayerNorm(config.hidden_size) + self.glu = Phi4MultimodalAudioGluPointWiseConv(config) + self.dw_sep_conv_1d = Phi4MultimodalAudioDepthWiseSeperableConv1d(config, padding=config.kernel_size - 1) + self.act = ACT2FN[config.conv_activation] + self.ext_pw_conv_1d = nn.Conv1d(config.hidden_size, config.ext_pw_out_channel, kernel_size=1, stride=1) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = self.glu(self.layer_norm(hidden_states)) + hidden_states = self.dw_sep_conv_1d(hidden_states.permute([0, 2, 1])) + + if self.kernel_size > 1: + hidden_states = hidden_states[:, :, : -(self.kernel_size - 1)] + + hidden_states = self.act(hidden_states) + hidden_states = self.ext_pw_conv_1d(hidden_states) + out = self.dropout(hidden_states.permute([0, 2, 1])) + return out + + +class Phi4MultimodalAudioConformerEncoderLayer(nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + + self.feed_forward_in = Phi4MultimodalAudioMLP(config) + self.self_attn = Phi4MultimodalAudioAttention(config) + self.conv = Phi4MultimodalAudioConvModule(config) + self.feed_forward_out = Phi4MultimodalAudioMLP(config) + self.layer_norm_att = nn.LayerNorm(config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + ): + residual = hidden_states + 0.5 * self.feed_forward_in(hidden_states) + hidden_states = self.layer_norm_att(residual) + + hidden_states = residual + self.self_attn(hidden_states, attention_mask) + hidden_states = hidden_states + self.conv(hidden_states) + hidden_states = hidden_states + 0.5 * self.feed_forward_out(hidden_states) + + out = self.layer_norm(hidden_states) + + return out + + +class Phi4MultimodalAudioNemoConvSubsampling(torch.nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.subsampling_factor = config.time_reduction + self.sampling_num = int(math.log(self.subsampling_factor, 2)) + self.act_fn = ACT2FN[config.nemo_activation] + conv_channels = config.nemo_conv_channels + + layers = [ + nn.Conv2d(1, conv_channels, kernel_size=3, stride=2, padding=1), + self.act_fn, + ] + for _ in range(self.sampling_num - 1): + layers.extend( + [ + nn.Conv2d(conv_channels, conv_channels, kernel_size=3, stride=2, padding=1, groups=conv_channels), + nn.Conv2d(conv_channels, conv_channels, kernel_size=1, stride=1, padding=0, groups=1), + self.act_fn, + ] + ) + + # Aggregate the layers + self.conv = torch.nn.Sequential(*layers) + self.out = torch.nn.Linear(conv_channels * config.nemo_final_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor, mask: Optional[torch.Tensor]): + # Unsqueeze Channel Axis + hidden_states = hidden_states.unsqueeze(1) + hidden_states = self.conv(hidden_states) + + # Flatten Channel and Frequency Axes + b, _, t, _ = hidden_states.size() + hidden_states = self.out(hidden_states.transpose(1, 2).reshape(b, t, -1)) + + if mask is None: + return hidden_states, None + + max_audio_length = hidden_states.shape[1] + feature_lens = mask.sum(1) + padding_length = torch.ceil(feature_lens / self.subsampling_factor) + arange_ = torch.arange(0, max_audio_length, device=hidden_states.device) + pad_mask = arange_.expand(padding_length.size(0), -1) < padding_length.unsqueeze(1) + return hidden_states, pad_mask.unsqueeze(1) + + +class Phi4MultimodalAudioRelativeAttentionBias(nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + + self.max_distance = config.bias_max_distance + self.symmetric = config.bias_symmetric + self.num_buckets = self.max_distance + if not config.bias_symmetric: + self.num_buckets *= 2 + self.bias_values = nn.Embedding(self.num_buckets, config.num_attention_heads) + + def forward(self, x): + # instantiate bias compatible with shape of x + max_pos = x.size(1) + context_position = torch.arange(max_pos, device=x.device, dtype=torch.long)[:, None] + memory_position = torch.arange(max_pos, device=x.device, dtype=torch.long)[None, :] + relative_position = memory_position - context_position + # clipping to a maximum distance using ops that play well with ONNX export + relative_position = relative_position.masked_fill(relative_position < -self.max_distance, -self.max_distance) + relative_position = relative_position.masked_fill( + relative_position > self.max_distance - 1, self.max_distance - 1 + ) + + # mapping from relative position to index in the bias parameter + bias_idx = relative_position + bias_idx = bias_idx.abs() if self.symmetric else bias_idx + self.num_buckets // 2 + + att_bias = self.bias_values(bias_idx) + att_bias = att_bias.permute(2, 0, 1).unsqueeze(0) + + return att_bias + + +class Phi4MultimodalAudioMeanVarianceNormLayer(nn.Module): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.register_buffer("global_mean", torch.zeros(config.input_size)) + self.register_buffer("global_invstd", torch.ones(config.input_size)) + + def forward(self, x): + return (x - self.global_mean) * self.global_invstd + + +class Phi4MultimodalAudioPreTrainedModel(PreTrainedModel): + config_class = Phi4MultimodalAudioConfig + supports_gradient_checkpointing = True + _no_split_modules = ["Phi4MultimodalAudioConformerEncoderLayer"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +class Phi4MultimodalAudioModel(Phi4MultimodalAudioPreTrainedModel): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__(config) + self.config = config + + self.encoder_embedding = Phi4MultimodalAudioMeanVarianceNormLayer(config) + self.embed = Phi4MultimodalAudioNemoConvSubsampling(config) + self.relative_attention_bias_layer = Phi4MultimodalAudioRelativeAttentionBias(config) + self.encoders = nn.ModuleList( + [Phi4MultimodalAudioConformerEncoderLayer(config) for _ in range(config.num_blocks)] + ) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk): + # Create mask matrix for streaming + # S stores start index. if chunksize is 18, s is [0,18,36,....] + chunk_start_idx = np.arange(0, seq_len, chunk_size) + # avoid randomness when run evaluation or decoding + if self.training and np.random.rand() > 0.5: + # Either first or last chunk is not complete. + # If only the last one is not complete, EOS is not effective + chunk_start_idx = seq_len - chunk_start_idx + chunk_start_idx = chunk_start_idx[::-1] + chunk_start_idx = chunk_start_idx[:-1] + chunk_start_idx = np.insert(chunk_start_idx, 0, 0) + + enc_streaming_mask = ( + adaptive_enc_mask(seq_len, chunk_start_idx, left_window=left_chunk) + .unsqueeze(0) + .expand([batch_size, -1, -1]) + ) + return enc_streaming_mask + + def forward_embeddings(self, hidden_states, masks): + """Forwarding the inputs through the top embedding layers""" + seq_len = math.ceil(hidden_states.shape[1] / self.config.time_reduction) + if seq_len <= 0: + raise ValueError( + f"The squence length after time reduction is invalid: {seq_len}. Your input feature is too short." + ) + + batch_size = hidden_states.shape[0] + + enc_streaming_mask = self._streaming_mask(seq_len, batch_size, self.config.chunk_size, self.config.left_chunk) + enc_streaming_mask = enc_streaming_mask.to(hidden_states.device) + + hidden_states, masks = self.embed(hidden_states, masks) + + streaming_mask = enc_streaming_mask + if streaming_mask is not None and masks is not None: + hs_mask = masks & streaming_mask + elif masks is not None: + hs_mask = masks + else: + hs_mask = streaming_mask + + return hidden_states, hs_mask, masks + + def calculate_hs_mask(self, hidden_states, device, mask): + max_audio_length = hidden_states.shape[1] + batch_size = hidden_states.shape[0] + enc_streaming_mask = self._streaming_mask( + max_audio_length, batch_size, self.config.chunk_size, self.config.left_chunk + ) + enc_streaming_mask = enc_streaming_mask.to(device) + if mask is None: + return enc_streaming_mask + + feature_lens = mask.sum(1) + padding_length = feature_lens + pad_mask = torch.arange(0, max_audio_length, device=device).expand( + padding_length.size(0), -1 + ) < padding_length.unsqueeze(1) + pad_mask = pad_mask.unsqueeze(1) + pad_mask = pad_mask & enc_streaming_mask + return pad_mask + + def forward(self, hidden_states: torch.Tensor, mask: Optional[torch.Tensor]): + hidden_states = self.encoder_embedding(hidden_states) + hidden_states, hs_mask, mask = self.forward_embeddings(hidden_states, mask) + + unfolded = False + bs, seq_len, _ = hidden_states.shape + max_seq_len = 500 # maxium position for absolute positional encoding + if seq_len > max_seq_len: + # audio sequence is longer than max_seq_len, unfold it into chunks of max_seq_len + unfolded = True + # the unfold op will drop residual frames, pad it to the multiple of max_seq_len + if seq_len % max_seq_len > 0: + chunk_pad_size = max_seq_len - (seq_len % max_seq_len) + else: + chunk_pad_size = 0 + if chunk_pad_size > 0: + hidden_states_pad = F.pad(hidden_states, (0, 0, 0, chunk_pad_size), "constant", 0) + hidden_states = hidden_states_pad.to(hidden_states.device) + + hidden_states = unfold_tensor(hidden_states, max_seq_len) + masks_unfold = None + if mask is not None: + # revise hs_mask here because the previous calculated hs_mask did not consider extra pad + subsampled_pad_mask = mask.squeeze(1) # [bz, subsampled_unmask_seq_len] + extra_padded_subsamlped_pad_mask = F.pad( + subsampled_pad_mask, (0, chunk_pad_size), "constant", False + ) # extra padding to the pad mask + extra_padded_subsamlped_pad_mask = extra_padded_subsamlped_pad_mask.unsqueeze(-1).float() + masks_unfold = unfold_tensor( + extra_padded_subsamlped_pad_mask, max_seq_len + ) # unfold the pad mask like we did to the input tensor + masks_unfold = masks_unfold.squeeze(-1).bool() # unfold op does not support bool tensor + hs_mask = self.calculate_hs_mask( + hidden_states, hidden_states.device, masks_unfold + ) # calculate hs_mask based on the unfolded pad mask + + relative_attention_bias = self.relative_attention_bias_layer(hidden_states) + attention_mask = hs_mask.unsqueeze(1) + relative_attention_bias + + for layer in self.encoders: + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + ) + else: + hidden_states = layer(hidden_states, attention_mask) + + if unfolded: + embed_dim = hidden_states.shape[-1] + hidden_states = hidden_states.reshape(bs, -1, embed_dim) + # if we ever padded before unfolding, we need to remove the padding + if chunk_pad_size > 0: + hidden_states = hidden_states[:, :-chunk_pad_size, :] + + return hidden_states + + +def unfold_tensor(tensor, max_seq_len): + """ + For a given tensor with shape of (N, T, D), if sequence length T is longer than max_seq_len, + this function unfold it to a (NT', max_seq_len, D) where T' is T // max_seq_len. + Args: + tensor: N, T, D + """ + _, _, D = tensor.shape + tensor = tensor.transpose(-1, -2) + # N x D x 1 x T => N x (D x max_seq_len) x T' + tensor = F.unfold(tensor[..., None, :], kernel_size=(1, max_seq_len), stride=(1, max_seq_len)) + + new_bsz, _, slen = tensor.shape + tensor = tensor.view(new_bsz, -1, max_seq_len, slen) + tensor = tensor.permute(0, 3, 2, 1) + tensor = tensor.view(-1, max_seq_len, D).contiguous() + return tensor + + +def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0): + """ + The function is very important for Transformer Transducer Streaming mode + Args: + xs_len (int): sequence length + chunk_start_idx (list): first idx of each chunk, such as [0,18,36,48]. It also supports adaptive chunk size [0,10,15,45] + left_window (int): how many left chunks can be seen + right_window (int): how many right chunks can be seen. It is used for chunk overlap model. + Returns: + mask (torch.Tensor): a mask tensor for streaming model + """ + chunk_start_idx = torch.Tensor(chunk_start_idx).long() + start_pad = torch.nn.functional.pad( + chunk_start_idx, (1, 0) + ) # append 0 to the beginning, so it becomes [0, 0, 18, 36, 48] + end_pad = torch.nn.functional.pad( + chunk_start_idx, (0, 1), value=x_len + ) # append x_len to the end, so it becomes [0,18,36,48, x_len] + seq_range = torch.arange(0, x_len).unsqueeze(-1) + idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[:, 1] + seq_range_expand = torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1) + idx_left = idx - left_window + idx_left[idx_left < 0] = 0 + boundary_left = start_pad[idx_left] + mask_left = seq_range_expand >= boundary_left.unsqueeze(-1) + idx_right = idx + right_window + idx_right[idx_right > len(chunk_start_idx)] = len(chunk_start_idx) + boundary_right = end_pad[idx_right] + mask_right = seq_range_expand < boundary_right.unsqueeze(-1) + return mask_left & mask_right + + +class Phi4MultimodalAudioEmbedding(nn.Module): + def __init__(self, config: Phi4MultimodalConfig): + super().__init__() + self.config = config + self.layer_idx = config.audio_config.feature_layer + + self.drop = nn.Dropout(config.embd_pdrop) + self.encoder = Phi4MultimodalAudioModel._from_config(config.audio_config) + self.up_proj_for_speech = nn.Linear( + config.audio_config.hidden_size * config.audio_config.downsample_rate, config.hidden_size + ) + self.down_proj_for_speech = nn.Linear(config.hidden_size, config.hidden_size) + self.up_proj_for_vision_speech = nn.Linear( + config.audio_config.hidden_size * config.audio_config.downsample_rate, config.hidden_size + ) + self.down_proj_for_vision_speech = nn.Linear(config.hidden_size, config.hidden_size) + + def forward( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.Tensor, + audio_input_features: torch.FloatTensor, + audio_embed_sizes=None, + audio_attention_mask=None, + audio_projection_mode="speech", + ) -> torch.FloatTensor: + with torch.no_grad(): + positions_tuple = torch.nonzero(input_ids == self.config.audio_config.audio_token_id, as_tuple=True) + + up_proj = self.up_proj_for_speech if audio_projection_mode == "speech" else self.up_proj_for_vision_speech + down_proj = ( + self.down_proj_for_speech if audio_projection_mode == "speech" else self.down_proj_for_vision_speech + ) + + target_device = up_proj.bias.device + target_dtype = up_proj.bias.dtype + + audio_input_features = audio_input_features.to(device=target_device, dtype=target_dtype) + + audio_encoder_hidden_states = self.encoder(audio_input_features, audio_attention_mask) + audio_encoder_hidden_states = up_proj(audio_encoder_hidden_states) + audio_encoder_hidden_states = nn.functional.gelu(audio_encoder_hidden_states) + audio_embeds = down_proj(audio_encoder_hidden_states) + + merged_audio_embeds = torch.cat( + [audio_embeds[i, : audio_embed_sizes[i], :] for i in range(len(audio_embed_sizes))], dim=0 + ) + merged_audio_embeds = merged_audio_embeds.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device) + # Temporarily disable autocast to avoid issue on bf16 tensors + # Ref: https://github.com/pytorch/pytorch/issues/132715 + with torch.autocast(device_type=inputs_embeds.device.type, enabled=False): + audio_embeds = inputs_embeds.index_put( + indices=positions_tuple, values=merged_audio_embeds, accumulate=False + ) + + audio_embeds = self.drop(audio_embeds) + + return audio_embeds + + +#################################################### TEXT #################################################### + + +class Phi4MultimodalRMSNorm(Phi3RMSNorm): + pass + + +class Phi4MultimodalDecoderLayer(Phi3DecoderLayer): + pass + + +class Phi4MultimodalFeatureEmbedding(nn.Module): + """Image-audio embedding.""" + + def __init__(self, config: Phi4MultimodalConfig) -> None: + super().__init__() + self.config = config + self.image_token_id = config.vision_config.image_token_id + self.audio_token_id = config.audio_config.audio_token_id + self.image_embed = Phi4MultimodalImageEmbedding(config) + self.audio_embed = Phi4MultimodalAudioEmbedding(config) + + def forward( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.Tensor, + image_pixel_values: Optional[torch.FloatTensor] = None, + audio_input_features: Optional[torch.FloatTensor] = None, + image_sizes=None, + image_attention_mask=None, + audio_embed_sizes=None, + audio_attention_mask=None, + ) -> torch.FloatTensor: + with torch.no_grad(): + image_position_mask = (input_ids == self.config.vision_config.image_token_id).unsqueeze(-1) + non_image_position_mask = ~image_position_mask + + image_embeds = None + audio_embeds = None + if image_pixel_values is not None and (input_ids == self.image_token_id).any(): + image_embeds = self.image_embed( + input_ids, + inputs_embeds, + image_pixel_values=image_pixel_values, + image_sizes=image_sizes, + image_attention_mask=image_attention_mask, + ) + if audio_input_features is not None and (input_ids == self.audio_token_id).any(): + audio_projection_mode = "vision" if image_pixel_values is not None else "speech" + audio_embeds = self.audio_embed( + input_ids, + inputs_embeds, + audio_input_features=audio_input_features, + audio_embed_sizes=audio_embed_sizes, + audio_attention_mask=audio_attention_mask, + audio_projection_mode=audio_projection_mode, + ) + + # merge image and audio + if image_embeds is not None and audio_embeds is not None: + inputs_embeds = image_embeds * image_position_mask + audio_embeds * non_image_position_mask + elif image_embeds is not None: + inputs_embeds = image_embeds + elif audio_embeds is not None: + inputs_embeds = audio_embeds + + return inputs_embeds + + +PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding indices in `input_values`. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache`)`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + See our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + image_pixel_values (`torch.FloatTensor`, *optional*): + If the input contains images, these correspond to the pixel values after transformations (as returned by + the Processor) + image_sizes (`torch.LongTensor`, *optional*): + If the input contains images, these correspond to size of each image. + image_attention_mask (`torch.LongTensor`, *optional*): + Attention mask for the images. + audio_input_features (`torch.FloatTensor`, *optional*): + If the input contains audio samples, these correspond to the values after transformation (as returned by + the Processor). + audio_embed_sizes (`torch.Tensor`, *optional*): + Size of the audio inputs. + audio_attention_mask (`torch.Tensor, *optional*): + Attention mask for the audio inputs. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +class Phi4MultimodalModel(Phi3Model, nn.Module): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi4MultimodalMMDecoderLayer`] + Args: + config: Phi4MultimodalMMConfig + """ + + def __init__(self, config: Phi4MultimodalConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.embed_dropout = nn.Dropout(config.embd_pdrop) + + self.embed_tokens_extend = Phi4MultimodalFeatureEmbedding(config) + + self.layers = nn.ModuleList( + [Phi4MultimodalDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Phi4MultimodalRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + image_pixel_values: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.LongTensor] = None, + image_attention_mask=None, + audio_input_features: Optional[torch.FloatTensor] = None, + audio_embed_sizes=None, + audio_attention_mask=None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens_extend( + input_ids, + inputs_embeds, + image_pixel_values=image_pixel_values, + audio_input_features=audio_input_features, + image_sizes=image_sizes, + image_attention_mask=image_attention_mask, + audio_embed_sizes=audio_embed_sizes, + audio_attention_mask=audio_attention_mask, + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + +class Phi4MultimodalForCausalLM(Phi3ForCausalLM, nn.Module): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = Phi4MultimodalModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=Phi4MultimodalConfig) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + image_pixel_values: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.LongTensor] = None, + image_attention_mask=None, + audio_input_features: Optional[torch.FloatTensor] = None, + audio_embed_sizes=None, + audio_attention_mask=None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + Returns: + + Example: + ```python + >>> from transformers import AutoTokenizer, Phi4MultimodalForCausalLM + >>> model = Phi4MultimodalForCausalLM.from_pretrained("TBA") + >>> tokenizer = AutoTokenizer.from_pretrained("TBA") + >>> prompt = "This is an example script ." + >>> inputs = tokenizer(prompt, return_tensors="pt") + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + 'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum' + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + image_pixel_values=image_pixel_values, + image_sizes=image_sizes, + image_attention_mask=image_attention_mask, + audio_input_features=audio_input_features, + audio_embed_sizes=audio_embed_sizes, + audio_attention_mask=audio_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + image_pixel_values=None, + image_sizes=None, + image_attention_mask=None, + audio_input_features=None, + audio_embed_sizes=None, + audio_attention_mask=None, + cache_position=None, + position_ids=None, + use_cache=True, + logits_to_keep=0, + **kwargs, + ): + # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the + # process + + # When the first time input length reached long and short factor switching point, enforce re-compute cache + # It will cause downside of slower at this single token position, however, better than current failure. + if ( + past_key_values + and self.config.rope_scaling + and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1 + ): + past_length = cache_position[0] + if past_length <= self.config.original_max_position_embeddings: + past_key_values = None + + model_inputs = super().prepare_inputs_for_generation( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + image_pixel_values=image_pixel_values, + image_sizes=image_sizes, + image_attention_mask=image_attention_mask, + audio_input_features=audio_input_features, + audio_embed_sizes=audio_embed_sizes, + audio_attention_mask=audio_attention_mask, + cache_position=cache_position, + position_ids=position_ids, + use_cache=use_cache, + logits_to_keep=logits_to_keep, + **kwargs, + ) + return model_inputs + + +__all__ = [ + "Phi4MultimodalAudioPreTrainedModel", + "Phi4MultimodalAudioModel", + "Phi4MultimodalVisionPreTrainedModel", + "Phi4MultimodalVisionModel", + "Phi4MultimodalPreTrainedModel", # noqa + "Phi4MultimodalModel", + "Phi4MultimodalForCausalLM", + "Phi4MultimodalVisionConfig", + "Phi4MultimodalAudioConfig", + "Phi4MultimodalConfig", +] diff --git a/src/transformers/models/phi4_multimodal/processing_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/processing_phi4_multimodal.py new file mode 100644 index 0000000000..d60275542f --- /dev/null +++ b/src/transformers/models/phi4_multimodal/processing_phi4_multimodal.py @@ -0,0 +1,194 @@ +# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Processor class for Phi4Multimodal +""" + +import re +from typing import List, Optional, Union + +from ...audio_utils import AudioInput +from ...image_processing_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import TextInput +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class Phi4MultimodalProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "audio_kwargs": { + "device": "cpu", + }, + } + + +class Phi4MultimodalProcessor(ProcessorMixin): + r""" + Constructs a Phi4Multimodal processor which raps an image processor, a audio processor, and a GPT tokenizer into a single processor. + + [`Phi4MultimodalProcessor`] offers all the functionalities of [`Phi4MultimodalImageProcessorFast`] and [`GPT2Tokenizer`]. See the + [`~Phi4MultimodalProcessor.__call__`] and [`~Phi4MultimodalProcessor.decode`] for more information. + + Args: + image_processor (`Phi4MultimodalImageProcessorFast`): + The image processor to use for images. + audio_processor (`Phi4MultimodalFeatureExtractor`): + The audio processor to use for audio inputs. + tokenizer (`GPT2TokenizerFast`): + The tokenizer to use for text. + fake_image_token_pattern (`str`, *optional*, defaults to `r"<\|image_\d+\|>"`): + The fake image token pattern. + fake_audio_token_pattern (`str`, *optional*, defaults to `r"<\|audio_\d+\|>"`): + The fake audio token pattern. + """ + + attributes = ["image_processor", "audio_processor", "tokenizer"] + tokenizer_class = "GPT2TokenizerFast" + image_processor_class = "Phi4MultimodalImageProcessorFast" + audio_processor_class = "Phi4MultimodalFeatureExtractor" + valid_kwargs = ["chat_template", "fake_image_token_pattern", "fake_audio_token_pattern"] + + def __init__( + self, + image_processor, + audio_processor, + tokenizer, + fake_image_token_pattern: str = r"<\|image_\d+\|>", + fake_audio_token_pattern: str = r"<\|audio_\d+\|>", + **kwargs, + ): + super().__init__(image_processor, audio_processor, tokenizer, **kwargs) + self.fake_image_token_pattern = fake_image_token_pattern + self.fake_audio_token_pattern = fake_audio_token_pattern + + def __call__( + self, + text: Union[TextInput, List[TextInput]], + images: Optional[ImageInput] = None, + audios: Optional[AudioInput] = None, + **kwargs: Unpack[ProcessingKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forards the `text` + and `kwargs` arguments to GPT2Tokenizer's [`~GPT2Tokenizer.__call__`] if `text` is not `None` to encode + the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to + Phi4MultimodalImageProcessorFast's [`~Phi4MultimodalImageProcessorFast.__call__`] if `images` is not `None`. Please refer to the doctsring + of the above two methods for more information. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + audios (`List[Union[np.ndarray, torch.Tensor]]`): + List of the audios to be prepared. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model. + - **input_image_embeds** -- Pixel values to be fed to a model. + - **image_sizes** -- List of tuples specifying the size of each image in `input_image_embeds`. + - **image_attention_mask** -- List of attention masks for each image in `input_image_embeds`. + - **input_audio_embeds** -- Audio embeddings to be fed to a model. + - **audio_embed_sizes** -- List of integers specifying the size of each audio in `input_audio_embeds`. + """ + + output_kwargs = self._merge_kwargs(Phi4MultimodalProcessorKwargs, self.tokenizer.init_kwargs, **kwargs) + image_kwargs = output_kwargs["images_kwargs"] + audio_kwargs = output_kwargs["audio_kwargs"] + text_kwargs = output_kwargs["text_kwargs"] + + image_inputs = self.image_processor(images, **image_kwargs) if images is not None else {} + audio_inputs = self.audio_processor(audios, **audio_kwargs) if audios is not None else {} + + # We pop here for images as we don't need it later + num_img_tokens = image_inputs.pop("num_img_tokens", []) + audio_embed_sizes = audio_inputs.get("audio_embed_sizes", []) + + # Replace certain special tokens for compatibility + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError("Invalid input text. Please provide a string, or a list of strings") + + image_token = self.tokenizer.image_token + audio_token = self.tokenizer.audio_token + processed_text = [re.sub(self.fake_image_token_pattern, image_token, t) for t in text] + processed_text = [re.sub(self.fake_audio_token_pattern, audio_token, t) for t in processed_text] + + # Check that the number of special tokens is sound + concatenated_prompt = "".join(processed_text) + if concatenated_prompt.count(self.tokenizer.image_token) != len(num_img_tokens): + raise ValueError( + "You should add as much image tokens `<|image_i|>` in your prompt as you pass `images` to the processor" + ) + if concatenated_prompt.count(self.tokenizer.audio_token) != len(audio_embed_sizes): + raise ValueError( + "You should add as much audio tokens `<|audio_i|>` in your prompt as you pass `audios` to the processor" + ) + + # Add appropriate number of image/audio tokens (note that the count of replacement is dynamic) + image_count_iter = iter(num_img_tokens) + audio_count_iter = iter(audio_embed_sizes) + processed_text = [ + re.sub(re.escape(image_token), lambda _: image_token * next(image_count_iter), t) for t in processed_text + ] + processed_text = [ + re.sub(re.escape(audio_token), lambda _: audio_token * next(audio_count_iter), t) for t in processed_text + ] + + text_inputs = self.tokenizer(processed_text, **text_kwargs) + + # prepare batch feature + data = { + **text_inputs, + **image_inputs, + **audio_inputs, + } + + return BatchFeature(data=data) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to GPT2Tokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to GPT2Tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + audio_processor_input_names = self.audio_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names + audio_processor_input_names)) + + +__all__ = ["Phi4MultimodalProcessor"] diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 85eea3cb10..a7051cffca 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -7746,6 +7746,55 @@ class Phi3PreTrainedModel(metaclass=DummyObject): requires_backends(self, ["torch"]) +class Phi4MultimodalAudioModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Phi4MultimodalAudioPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Phi4MultimodalForCausalLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Phi4MultimodalModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Phi4MultimodalPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Phi4MultimodalVisionModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Phi4MultimodalVisionPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class PhimoeForCausalLM(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/transformers/utils/dummy_torchvision_objects.py b/src/transformers/utils/dummy_torchvision_objects.py index 23a55f33b0..50314fc55e 100644 --- a/src/transformers/utils/dummy_torchvision_objects.py +++ b/src/transformers/utils/dummy_torchvision_objects.py @@ -93,6 +93,13 @@ class LlavaOnevisionImageProcessorFast(metaclass=DummyObject): requires_backends(self, ["torchvision"]) +class Phi4MultimodalImageProcessorFast(metaclass=DummyObject): + _backends = ["torchvision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torchvision"]) + + class PixtralImageProcessorFast(metaclass=DummyObject): _backends = ["torchvision"] diff --git a/tests/models/phi4_multimodal/__init__.py b/tests/models/phi4_multimodal/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/models/phi4_multimodal/test_modeling_phi4_multimodal.py b/tests/models/phi4_multimodal/test_modeling_phi4_multimodal.py new file mode 100644 index 0000000000..737e712a34 --- /dev/null +++ b/tests/models/phi4_multimodal/test_modeling_phi4_multimodal.py @@ -0,0 +1,405 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import tempfile +import unittest + +import requests +from parameterized import parameterized + +from transformers import ( + AutoModelForCausalLM, + AutoProcessor, + GenerationConfig, + Phi4MultimodalAudioConfig, + Phi4MultimodalConfig, + Phi4MultimodalForCausalLM, + Phi4MultimodalModel, + Phi4MultimodalVisionConfig, + is_torch_available, + is_vision_available, +) +from transformers.testing_utils import ( + require_soundfile, + require_torch, + slow, + torch_device, +) +from transformers.utils import is_soundfile_available + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor + + +if is_torch_available(): + import torch + + +if is_vision_available(): + from PIL import Image + + +if is_soundfile_available(): + import soundfile + + +class Phi4MultimodalModelTester: + def __init__( + self, + parent, + batch_size=2, + seq_length=12, + image_seq_length=275, + audio_seq_length=8, + is_training=True, + num_hidden_layers=2, + vocab_size=49, + hidden_size=32, + intermediate_size=64, + num_attention_heads=8, + num_key_value_heads=4, + bos_token_id=0, + eos_token_id=0, + pad_token_id=0, + image_token_id=1, + audio_token_id=2, + image_size=16, + audio_size=12, + audio_config=Phi4MultimodalAudioConfig( + num_blocks=2, + hidden_size=32, + num_attention_heads=8, + intermediate_size=48, + depthwise_seperable_out_channel=128, + nemo_conv_channels=128, + ), + vision_config=Phi4MultimodalVisionConfig( + num_hidden_layers=2, + hidden_size=32, + intermediate_size=64, + num_attention_heads=8, + crop_size=16, + ), + ): + self.parent = parent + self.num_hidden_layers = num_hidden_layers + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.bos_token_id = bos_token_id + self.pad_token_id = pad_token_id + self.eos_token_id = eos_token_id + self.image_token_id = image_token_id + self.audio_token_id = audio_token_id + self.audio_config = audio_config + self.vision_config = vision_config + + self.is_training = is_training + self.batch_size = batch_size + self.seq_length = seq_length + image_seq_length + audio_seq_length + self.image_seq_length = image_seq_length + self.audio_seq_length = audio_seq_length + self.image_size = image_size + self.audio_size = audio_size + self.num_channels = 3 + + def get_config(self): + return Phi4MultimodalConfig( + num_hidden_layers=self.num_hidden_layers, + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, + pad_token_id=self.pad_token_id, + vision_config=self.vision_config, + audio_config=self.audio_config, + ) + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + # The shapes corresponds to the inputs for image of size 16x16 + image_pixel_values = floats_tensor([self.batch_size, 2, self.num_channels, self.image_size, self.image_size]) + image_attention_mask = torch.ones(self.batch_size, 2, 1, 1) + image_sizes = torch.tensor( + [[self.image_size, self.image_size]] * self.batch_size, dtype=torch.long, device=torch_device + ) + + # Feature sizes returned by an audio of size 10000 + audio_input_features = floats_tensor([self.batch_size, 61, 80]) + audio_embed_sizes = torch.tensor([self.audio_seq_length] * self.batch_size, dtype=torch.long) + + input_ids[input_ids == self.pad_token_id] = self.pad_token_id + 1 # random value but not pad token + input_ids[-1, 0] = self.pad_token_id # mask the last text token + input_ids[:, -self.image_seq_length - self.audio_seq_length : -self.audio_seq_length] = self.image_token_id + input_ids[:, -self.audio_seq_length :] = self.audio_token_id + + attention_mask = torch.ones_like(input_ids) + attention_mask[-1, 0] = 0 # mask the last text token + config = self.get_config() + + return ( + config, + input_ids, + attention_mask, + image_pixel_values, + image_attention_mask, + image_sizes, + audio_input_features, + audio_embed_sizes, + ) + + def prepare_config_and_inputs_for_common(self): + ( + config, + input_ids, + attention_mask, + image_pixel_values, + image_attention_mask, + image_sizes, + audio_input_features, + audio_embed_sizes, + ) = self.prepare_config_and_inputs() + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "image_pixel_values": image_pixel_values, + "image_attention_mask": image_attention_mask, + "image_sizes": image_sizes, + "audio_input_features": audio_input_features, + "audio_embed_sizes": audio_embed_sizes, + } + return config, inputs_dict + + def create_and_check_model(self, config, input_ids, attention_mask): + model = Phi4MultimodalForCausalLM(config=config) + model.to(torch_device) + model.eval() + with torch.autocast(device_type="cuda", dtype=torch.float16): + logits = model( + input_ids=input_ids, + attention_mask=attention_mask, + return_dict=True, + )["logits"] + self.parent.assertEqual(logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + self.parent.assertFalse(torch.isnan(logits).any().item()) + + +@require_torch +class Phi4MultimodalModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + """ + Model tester for `Phi4Multimodal`. + """ + + all_model_classes = (Phi4MultimodalForCausalLM, Phi4MultimodalModel) if is_torch_available() else () + test_pruning = False + test_head_masking = False + _is_composite = True + + def setUp(self): + self.model_tester = Phi4MultimodalModelTester(self) + self.config_tester = ConfigTester(self, config_class=Phi4MultimodalConfig) + + @unittest.skip(reason="Unstable test") + def test_initialization(self): + pass + + @unittest.skip(reason="Right padding not supported") + def test_flash_attn_2_inference_equivalence_right_padding(self): + pass + + @unittest.skip(reason="This one tries to use right padding as well") + def test_eager_matches_fa2_generate(self): + pass + + @unittest.skip(reason="Depending on input modalities, some params may not have gradients") + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip(reason="Depending on input modalities, some params may not have gradients") + def test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip(reason="Depending on input modalities, some params may not have gradients") + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + @unittest.skip(reason="Test tries to instantiate dynamic cache with an arg") + def test_multi_gpu_data_parallel_forward(self): + pass + + @unittest.skip(reason="Test is only for old attention format") + def test_sdpa_can_dispatch_composite_models(self): + pass + + @unittest.skip(reason="Static cache supported only for text-only inputs (not images or audios)") + def test_generate_from_inputs_embeds_with_static_cache(self): + pass + + @unittest.skip(reason="Static cache supported only for text-only inputs (not images or audios)") + def test_generate_with_static_cache(self): + pass + + @unittest.skip( + reason="Supported only for text-only inputs (otherwise dynamic control flows for multimodal inputs)" + ) + def test_generate_compilation_all_outputs(self): + pass + + @unittest.skip( + reason="Supported only for text-only inputs (otherwise dynamic control flows for multimodal inputs)" + ) + def test_generate_compile_model_forward(self): + pass + + @parameterized.expand([("random",), ("same",)]) + @unittest.skip(reason="`image_attention_mask` has a specific shape") + def test_assisted_decoding_matches_greedy_search(self, assistant_type): + pass + + @unittest.skip(reason="`image_attention_mask` has a specific shape") + def test_assisted_decoding_sample(self): + pass + + @unittest.skip(reason="`image_attention_mask` has a specific shape") + def test_prompt_lookup_decoding_matches_greedy_search(self): + pass + + @unittest.skip(reason="Cannot unpad inputs for all modalities so easily") + def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): + pass + + @unittest.skip(reason="Dynamo error") + def test_flex_attention_with_grads(self): + pass + + +@require_torch +@slow +class Phi4MultimodalIntegrationTest(unittest.TestCase): + checkpoint_path = "microsoft/Phi-4-multimodal-instruct" + image_url = "https://www.ilankelman.org/stopsigns/australia.jpg" + audio_url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/f2641_0_throatclearing.wav" + + def setUp(self): + self.processor = AutoProcessor.from_pretrained(self.checkpoint_path) + self.generation_config = GenerationConfig(max_new_tokens=20, do_sample=False) + self.user_token = "<|user|>" + self.assistant_token = "<|assistant|>" + self.end_token = "<|end|>" + self.image = Image.open(requests.get(self.image_url, stream=True).raw) + with tempfile.NamedTemporaryFile(mode="w+b", suffix=".wav") as tmp: + tmp.write(requests.get(self.audio_url, stream=True).raw.data) + tmp.flush() + tmp.seek(0) + self.audio, self.sampling_rate = soundfile.read(tmp.name) + + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + + def test_text_only_generation(self): + model = AutoModelForCausalLM.from_pretrained( + self.checkpoint_path, torch_dtype=torch.float16, device_map=torch_device + ) + + prompt = f"{self.user_token}What is the answer for 1+1? Explain it.{self.end_token}{self.assistant_token}" + inputs = self.processor(prompt, images=None, return_tensors="pt").to(torch_device) + + output = model.generate( + **inputs, + generation_config=self.generation_config, + ) + output = output[:, inputs["input_ids"].shape[1] :] + response = self.processor.batch_decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + + EXPECTED_RESPONSE = "The answer for 1+1 is 2. This is because when you add one to another" + + self.assertEqual(response, EXPECTED_RESPONSE) + + def test_vision_text_generation(self): + model = AutoModelForCausalLM.from_pretrained( + self.checkpoint_path, torch_dtype=torch.float16, device_map=torch_device + ) + + prompt = f"{self.user_token}<|image_1|>What is shown in this image?{self.end_token}{self.assistant_token}" + inputs = self.processor(prompt, images=self.image, return_tensors="pt").to(torch_device) + + output = model.generate( + **inputs, + generation_config=self.generation_config, + ) + output = output[:, inputs["input_ids"].shape[1] :] + response = self.processor.batch_decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + + EXPECTED_RESPONSE = "The image shows a vibrant scene at a street intersection in a city with a Chinese-influenced architectural" + + self.assertEqual(response, EXPECTED_RESPONSE) + + def test_multi_image_vision_text_generation(self): + model = AutoModelForCausalLM.from_pretrained( + self.checkpoint_path, torch_dtype=torch.float16, device_map=torch_device + ) + + images = [] + placeholder = "" + for i in range(1, 5): + url = f"https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-{i}-2048.jpg" + images.append(Image.open(requests.get(url, stream=True).raw)) + placeholder += f"<|image_{i}|>" + + prompt = f"{self.user_token}{placeholder}Summarize the deck of slides.{self.end_token}{self.assistant_token}" + inputs = self.processor(prompt, images, return_tensors="pt").to(torch_device) + + output = model.generate( + **inputs, + generation_config=self.generation_config, + ) + output = output[:, inputs["input_ids"].shape[1] :] + response = self.processor.batch_decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + + EXPECTED_RESPONSE = "The presentation provides an overview of Microsoft Azure, a cloud computing platform by Microsoft, and its various services" + + self.assertEqual(response, EXPECTED_RESPONSE) + + @require_soundfile + def test_audio_text_generation(self): + model = AutoModelForCausalLM.from_pretrained( + self.checkpoint_path, torch_dtype=torch.float16, device_map=torch_device + ) + + prompt = f"{self.user_token}<|audio_1|>What is happening in this audio?{self.end_token}{self.assistant_token}" + inputs = self.processor(prompt, audios=self.audio, sampling_rate=self.sampling_rate, return_tensors="pt").to( + torch_device + ) + + output = model.generate( + **inputs, + generation_config=self.generation_config, + ) + output = output[:, inputs["input_ids"].shape[1] :] + response = self.processor.batch_decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + + # Yes, it is truly the expected response... Even though the model correctly treats the audio file + EXPECTED_RESPONSE = "I'm sorry, but I can't listen to audio. However, if you describe the audio to me," + + self.assertEqual(response, EXPECTED_RESPONSE) diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py index 81ab0dea0d..5721e5913c 100644 --- a/utils/check_docstrings.py +++ b/utils/check_docstrings.py @@ -524,6 +524,7 @@ OBJECTS_TO_IGNORE = [ "TimeSeriesTransformerConfig", "TokenClassificationPipeline", "TrOCRConfig", + "Phi4MultimodalProcessor", "TrainerState", "TrainingArguments", "TrajectoryTransformerConfig", diff --git a/utils/check_repo.py b/utils/check_repo.py index 54bb9267c5..488754d2e8 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -89,6 +89,8 @@ PRIVATE_MODELS = [ "SmolVLMVisionTransformer", "AriaTextForCausalLM", "AriaTextModel", + "Phi4MultimodalAudioModel", + "Phi4MultimodalVisionModel", ] # Update this list for models that are not tested with a comment explaining the reason it should not be.