[Phi4] add multimodal chat template (#36996)

* phi4 chat template

* remove from valid kwargs
This commit is contained in:
Raushan Turganbay
2025-04-03 09:52:09 +02:00
committed by GitHub
parent c9302c0983
commit 98601cc818
3 changed files with 49 additions and 46 deletions

View File

@@ -30,14 +30,8 @@ found [here](https://github.com/huggingface/transformers/blob/main/src/transform
In the following, we demonstrate how to use it for inference depending on the input modalities (text, image, audio).
```python
import requests
import torch
import os
import io
from PIL import Image
import soundfile as sf
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from urllib.request import urlopen
# Define model path
@@ -52,21 +46,25 @@ model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device, tor
model.load_adapter(model_path, adapter_name="speech", device_map=device, adapter_kwargs={"subfolder": 'speech-lora'})
model.load_adapter(model_path, adapter_name="vision", device_map=device, adapter_kwargs={"subfolder": 'vision-lora'})
# Define prompt structure
user_prompt = '<|user|>'
assistant_prompt = '<|assistant|>'
prompt_suffix = '<|end|>'
# Part : Image Processing
messages = [
{
"role": "user",
"content": [
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
# Part 1: Image Processing
model.set_adapter("vision") # if loaded, activate the vision adapter
print("\n--- IMAGE PROCESSING ---")
image_url = 'https://www.ilankelman.org/stopsigns/australia.jpg'
prompt = f'{user_prompt}<|image_1|>What is shown in this image?{prompt_suffix}{assistant_prompt}'
print(f'>>> Prompt\n{prompt}')
# Download and open image
image = Image.open(requests.get(image_url, stream=True).raw)
inputs = processor(text=prompt, images=image, return_tensors='pt').to(device)
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(device, torch.float16)
# Generate response
generate_ids = model.generate(
@@ -80,19 +78,28 @@ response = processor.batch_decode(
)[0]
print(f'>>> Response\n{response}')
# Part 2: Audio Processing
model.set_adapter("speech") # if loaded, activate the speech adapter
print("\n--- AUDIO PROCESSING ---")
audio_url = "https://upload.wikimedia.org/wikipedia/commons/b/b0/Barbara_Sahakian_BBC_Radio4_The_Life_Scientific_29_May_2012_b01j5j24.flac"
speech_prompt = "Transcribe the audio to text, and then translate the audio to French. Use <sep> as a separator between the original transcript and the translation."
prompt = f'{user_prompt}<|audio_1|>{speech_prompt}{prompt_suffix}{assistant_prompt}'
print(f'>>> Prompt\n{prompt}')
messages = [
{
"role": "user",
"content": [
{"type": "audio", "url": audio_url},
{"type": "text", "text": "Transcribe the audio to text, and then translate the audio to French. Use <sep> as a separator between the origina transcript and the translation."},
],
},
]
# Downlowd and open audio file
audio, sample_rate = sf.read(io.BytesIO(urlopen(audio_url).read()))
# Process with the model
inputs = processor(text=prompt, audios=audio, sample_rate=sample_rate, return_tensors='pt').to(device)
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
sample_rate=sample_rate,
).to(device, torch.float16)
generate_ids = model.generate(
**inputs,

View File

@@ -29,7 +29,7 @@ from ...image_processing_utils_fast import (
Unpack,
convert_to_rgb,
)
from ...image_utils import ImageInput, make_list_of_images, valid_images
from ...image_utils import ImageInput, make_flat_list_of_images, valid_images
from ...utils import TensorType, logging
@@ -175,7 +175,7 @@ class Phi4MultimodalImageProcessorFast(BaseImageProcessorFast):
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
images = make_list_of_images(images)
images = make_flat_list_of_images(images)
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "

View File

@@ -62,26 +62,22 @@ class Phi4MultimodalProcessor(ProcessorMixin):
tokenizer_class = "GPT2TokenizerFast"
image_processor_class = "Phi4MultimodalImageProcessorFast"
audio_processor_class = "Phi4MultimodalFeatureExtractor"
valid_kwargs = ["chat_template", "fake_image_token_pattern", "fake_audio_token_pattern"]
valid_kwargs = ["chat_template"]
def __init__(
self,
image_processor,
audio_processor,
tokenizer,
fake_image_token_pattern: str = r"<\|image_\d+\|>",
fake_audio_token_pattern: str = r"<\|audio_\d+\|>",
**kwargs,
):
super().__init__(image_processor, audio_processor, tokenizer, **kwargs)
self.fake_image_token_pattern = fake_image_token_pattern
self.fake_audio_token_pattern = fake_audio_token_pattern
def __call__(
self,
text: Union[TextInput, List[TextInput]],
images: Optional[ImageInput] = None,
audios: Optional[AudioInput] = None,
audio: Optional[AudioInput] = None,
**kwargs: Unpack[ProcessingKwargs],
) -> BatchFeature:
"""
@@ -99,7 +95,7 @@ class Phi4MultimodalProcessor(ProcessorMixin):
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
audios (`List[Union[np.ndarray, torch.Tensor]]`):
audio (`List[Union[np.ndarray, torch.Tensor]]`):
List of the audios to be prepared.
Returns:
@@ -120,7 +116,7 @@ class Phi4MultimodalProcessor(ProcessorMixin):
text_kwargs = output_kwargs["text_kwargs"]
image_inputs = self.image_processor(images, **image_kwargs) if images is not None else {}
audio_inputs = self.audio_processor(audios, **audio_kwargs) if audios is not None else {}
audio_inputs = self.audio_processor(audio, **audio_kwargs) if audio is not None else {}
# We pop here for images as we don't need it later
num_img_tokens = image_inputs.pop("num_img_tokens", [])
@@ -134,25 +130,25 @@ class Phi4MultimodalProcessor(ProcessorMixin):
image_token = self.tokenizer.image_token
audio_token = self.tokenizer.audio_token
processed_text = [re.sub(self.fake_image_token_pattern, image_token, t) for t in text]
processed_text = [re.sub(self.fake_audio_token_pattern, audio_token, t) for t in processed_text]
# Check that the number of special tokens is sound
concatenated_prompt = "".join(processed_text)
if concatenated_prompt.count(self.tokenizer.image_token) != len(num_img_tokens):
concatenated_prompt = "".join(text)
if concatenated_prompt.count(image_token) != len(num_img_tokens):
raise ValueError(
"You should add as much image tokens `<|image_i|>` in your prompt as you pass `images` to the processor"
"You should add as much image tokens `<|image|>` in your prompt as you pass `images` to the processor. ",
f"Input contains {concatenated_prompt.count(image_token)} tokens != {len(num_img_tokens)} images",
)
if concatenated_prompt.count(self.tokenizer.audio_token) != len(audio_embed_sizes):
if concatenated_prompt.count(audio_token) != len(audio_embed_sizes):
raise ValueError(
"You should add as much audio tokens `<|audio_i|>` in your prompt as you pass `audios` to the processor"
"You should add as much audio tokens `<|audio|>` in your prompt as you pass `audios` to the processor. "
f"Input contains {concatenated_prompt.count(audio_token)} tokens != {len(audio_embed_sizes)} audios"
)
# Add appropriate number of image/audio tokens (note that the count of replacement is dynamic)
image_count_iter = iter(num_img_tokens)
audio_count_iter = iter(audio_embed_sizes)
processed_text = [
re.sub(re.escape(image_token), lambda _: image_token * next(image_count_iter), t) for t in processed_text
re.sub(re.escape(image_token), lambda _: image_token * next(image_count_iter), t) for t in text
]
processed_text = [
re.sub(re.escape(audio_token), lambda _: audio_token * next(audio_count_iter), t) for t in processed_text