[Phi4] add multimodal chat template (#36996)

* phi4 chat template

* remove from valid kwargs
This commit is contained in:
Raushan Turganbay
2025-04-03 09:52:09 +02:00
committed by GitHub
parent c9302c0983
commit 98601cc818
3 changed files with 49 additions and 46 deletions

View File

@@ -30,14 +30,8 @@ found [here](https://github.com/huggingface/transformers/blob/main/src/transform
In the following, we demonstrate how to use it for inference depending on the input modalities (text, image, audio). In the following, we demonstrate how to use it for inference depending on the input modalities (text, image, audio).
```python ```python
import requests
import torch import torch
import os
import io
from PIL import Image
import soundfile as sf
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from urllib.request import urlopen
# Define model path # Define model path
@@ -52,21 +46,25 @@ model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device, tor
model.load_adapter(model_path, adapter_name="speech", device_map=device, adapter_kwargs={"subfolder": 'speech-lora'}) model.load_adapter(model_path, adapter_name="speech", device_map=device, adapter_kwargs={"subfolder": 'speech-lora'})
model.load_adapter(model_path, adapter_name="vision", device_map=device, adapter_kwargs={"subfolder": 'vision-lora'}) model.load_adapter(model_path, adapter_name="vision", device_map=device, adapter_kwargs={"subfolder": 'vision-lora'})
# Define prompt structure # Part 1: Image Processing
user_prompt = '<|user|>' messages = [
assistant_prompt = '<|assistant|>' {
prompt_suffix = '<|end|>' "role": "user",
"content": [
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
# Part 1: Image Processing
model.set_adapter("vision") # if loaded, activate the vision adapter model.set_adapter("vision") # if loaded, activate the vision adapter
print("\n--- IMAGE PROCESSING ---") inputs = processor.apply_chat_template(
image_url = 'https://www.ilankelman.org/stopsigns/australia.jpg' messages,
prompt = f'{user_prompt}<|image_1|>What is shown in this image?{prompt_suffix}{assistant_prompt}' add_generation_prompt=True,
print(f'>>> Prompt\n{prompt}') tokenize=True,
return_dict=True,
# Download and open image return_tensors="pt",
image = Image.open(requests.get(image_url, stream=True).raw) ).to(device, torch.float16)
inputs = processor(text=prompt, images=image, return_tensors='pt').to(device)
# Generate response # Generate response
generate_ids = model.generate( generate_ids = model.generate(
@@ -80,19 +78,28 @@ response = processor.batch_decode(
)[0] )[0]
print(f'>>> Response\n{response}') print(f'>>> Response\n{response}')
# Part 2: Audio Processing # Part 2: Audio Processing
model.set_adapter("speech") # if loaded, activate the speech adapter model.set_adapter("speech") # if loaded, activate the speech adapter
print("\n--- AUDIO PROCESSING ---")
audio_url = "https://upload.wikimedia.org/wikipedia/commons/b/b0/Barbara_Sahakian_BBC_Radio4_The_Life_Scientific_29_May_2012_b01j5j24.flac" audio_url = "https://upload.wikimedia.org/wikipedia/commons/b/b0/Barbara_Sahakian_BBC_Radio4_The_Life_Scientific_29_May_2012_b01j5j24.flac"
speech_prompt = "Transcribe the audio to text, and then translate the audio to French. Use <sep> as a separator between the original transcript and the translation." messages = [
prompt = f'{user_prompt}<|audio_1|>{speech_prompt}{prompt_suffix}{assistant_prompt}' {
print(f'>>> Prompt\n{prompt}') "role": "user",
"content": [
{"type": "audio", "url": audio_url},
{"type": "text", "text": "Transcribe the audio to text, and then translate the audio to French. Use <sep> as a separator between the original transcript and the translation."},
],
},
]
# Download and open audio file inputs = processor.apply_chat_template(
audio, sample_rate = sf.read(io.BytesIO(urlopen(audio_url).read())) messages,
add_generation_prompt=True,
# Process with the model tokenize=True,
inputs = processor(text=prompt, audios=audio, sample_rate=sample_rate, return_tensors='pt').to(device) return_dict=True,
return_tensors="pt",
sample_rate=sample_rate,
).to(device, torch.float16)
generate_ids = model.generate( generate_ids = model.generate(
**inputs, **inputs,

View File

@@ -29,7 +29,7 @@ from ...image_processing_utils_fast import (
Unpack, Unpack,
convert_to_rgb, convert_to_rgb,
) )
from ...image_utils import ImageInput, make_list_of_images, valid_images from ...image_utils import ImageInput, make_flat_list_of_images, valid_images
from ...utils import TensorType, logging from ...utils import TensorType, logging
@@ -175,7 +175,7 @@ class Phi4MultimodalImageProcessorFast(BaseImageProcessorFast):
image_mean = image_mean if image_mean is not None else self.image_mean image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std image_std = image_std if image_std is not None else self.image_std
images = make_list_of_images(images) images = make_flat_list_of_images(images)
if not valid_images(images): if not valid_images(images):
raise ValueError( raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "

View File

@@ -62,26 +62,22 @@ class Phi4MultimodalProcessor(ProcessorMixin):
tokenizer_class = "GPT2TokenizerFast" tokenizer_class = "GPT2TokenizerFast"
image_processor_class = "Phi4MultimodalImageProcessorFast" image_processor_class = "Phi4MultimodalImageProcessorFast"
audio_processor_class = "Phi4MultimodalFeatureExtractor" audio_processor_class = "Phi4MultimodalFeatureExtractor"
valid_kwargs = ["chat_template", "fake_image_token_pattern", "fake_audio_token_pattern"] valid_kwargs = ["chat_template"]
def __init__( def __init__(
self, self,
image_processor, image_processor,
audio_processor, audio_processor,
tokenizer, tokenizer,
fake_image_token_pattern: str = r"<\|image_\d+\|>",
fake_audio_token_pattern: str = r"<\|audio_\d+\|>",
**kwargs, **kwargs,
): ):
super().__init__(image_processor, audio_processor, tokenizer, **kwargs) super().__init__(image_processor, audio_processor, tokenizer, **kwargs)
self.fake_image_token_pattern = fake_image_token_pattern
self.fake_audio_token_pattern = fake_audio_token_pattern
def __call__( def __call__(
self, self,
text: Union[TextInput, List[TextInput]], text: Union[TextInput, List[TextInput]],
images: Optional[ImageInput] = None, images: Optional[ImageInput] = None,
audios: Optional[AudioInput] = None, audio: Optional[AudioInput] = None,
**kwargs: Unpack[ProcessingKwargs], **kwargs: Unpack[ProcessingKwargs],
) -> BatchFeature: ) -> BatchFeature:
""" """
@@ -99,7 +95,7 @@ class Phi4MultimodalProcessor(ProcessorMixin):
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported. tensor. Both channels-first and channels-last formats are supported.
audios (`List[Union[np.ndarray, torch.Tensor]]`): audio (`List[Union[np.ndarray, torch.Tensor]]`):
List of the audios to be prepared. List of the audios to be prepared.
Returns: Returns:
@@ -120,7 +116,7 @@ class Phi4MultimodalProcessor(ProcessorMixin):
text_kwargs = output_kwargs["text_kwargs"] text_kwargs = output_kwargs["text_kwargs"]
image_inputs = self.image_processor(images, **image_kwargs) if images is not None else {} image_inputs = self.image_processor(images, **image_kwargs) if images is not None else {}
audio_inputs = self.audio_processor(audios, **audio_kwargs) if audios is not None else {} audio_inputs = self.audio_processor(audio, **audio_kwargs) if audio is not None else {}
# We pop here for images as we don't need it later # We pop here for images as we don't need it later
num_img_tokens = image_inputs.pop("num_img_tokens", []) num_img_tokens = image_inputs.pop("num_img_tokens", [])
@@ -134,25 +130,25 @@ class Phi4MultimodalProcessor(ProcessorMixin):
image_token = self.tokenizer.image_token image_token = self.tokenizer.image_token
audio_token = self.tokenizer.audio_token audio_token = self.tokenizer.audio_token
processed_text = [re.sub(self.fake_image_token_pattern, image_token, t) for t in text]
processed_text = [re.sub(self.fake_audio_token_pattern, audio_token, t) for t in processed_text]
# Check that the number of special tokens is sound # Check that the number of special tokens is sound
concatenated_prompt = "".join(processed_text) concatenated_prompt = "".join(text)
if concatenated_prompt.count(self.tokenizer.image_token) != len(num_img_tokens): if concatenated_prompt.count(image_token) != len(num_img_tokens):
raise ValueError( raise ValueError(
"You should add as much image tokens `<|image_i|>` in your prompt as you pass `images` to the processor" "You should add as much image tokens `<|image|>` in your prompt as you pass `images` to the processor. ",
f"Input contains {concatenated_prompt.count(image_token)} tokens != {len(num_img_tokens)} images",
) )
if concatenated_prompt.count(self.tokenizer.audio_token) != len(audio_embed_sizes): if concatenated_prompt.count(audio_token) != len(audio_embed_sizes):
raise ValueError( raise ValueError(
"You should add as much audio tokens `<|audio_i|>` in your prompt as you pass `audios` to the processor" "You should add as much audio tokens `<|audio|>` in your prompt as you pass `audios` to the processor. "
f"Input contains {concatenated_prompt.count(audio_token)} tokens != {len(audio_embed_sizes)} audios"
) )
# Add appropriate number of image/audio tokens (note that the count of replacement is dynamic) # Add appropriate number of image/audio tokens (note that the count of replacement is dynamic)
image_count_iter = iter(num_img_tokens) image_count_iter = iter(num_img_tokens)
audio_count_iter = iter(audio_embed_sizes) audio_count_iter = iter(audio_embed_sizes)
processed_text = [ processed_text = [
re.sub(re.escape(image_token), lambda _: image_token * next(image_count_iter), t) for t in processed_text re.sub(re.escape(image_token), lambda _: image_token * next(image_count_iter), t) for t in text
] ]
processed_text = [ processed_text = [
re.sub(re.escape(audio_token), lambda _: audio_token * next(audio_count_iter), t) for t in processed_text re.sub(re.escape(audio_token), lambda _: audio_token * next(audio_count_iter), t) for t in processed_text