[Phi4] add multimodal chat template (#36996)
* phi4 chat template * remove from valid kwargs
This commit is contained in:
committed by
GitHub
parent
c9302c0983
commit
98601cc818
@@ -30,14 +30,8 @@ found [here](https://github.com/huggingface/transformers/blob/main/src/transform
|
|||||||
In the following, we demonstrate how to use it for inference depending on the input modalities (text, image, audio).
|
In the following, we demonstrate how to use it for inference depending on the input modalities (text, image, audio).
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import requests
|
|
||||||
import torch
|
import torch
|
||||||
import os
|
|
||||||
import io
|
|
||||||
from PIL import Image
|
|
||||||
import soundfile as sf
|
|
||||||
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
|
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
|
||||||
from urllib.request import urlopen
|
|
||||||
|
|
||||||
|
|
||||||
# Define model path
|
# Define model path
|
||||||
@@ -52,21 +46,25 @@ model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device, tor
|
|||||||
model.load_adapter(model_path, adapter_name="speech", device_map=device, adapter_kwargs={"subfolder": 'speech-lora'})
|
model.load_adapter(model_path, adapter_name="speech", device_map=device, adapter_kwargs={"subfolder": 'speech-lora'})
|
||||||
model.load_adapter(model_path, adapter_name="vision", device_map=device, adapter_kwargs={"subfolder": 'vision-lora'})
|
model.load_adapter(model_path, adapter_name="vision", device_map=device, adapter_kwargs={"subfolder": 'vision-lora'})
|
||||||
|
|
||||||
# Define prompt structure
|
# Part : Image Processing
|
||||||
user_prompt = '<|user|>'
|
messages = [
|
||||||
assistant_prompt = '<|assistant|>'
|
{
|
||||||
prompt_suffix = '<|end|>'
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
|
||||||
|
{"type": "text", "text": "What is shown in this image?"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
# Part 1: Image Processing
|
|
||||||
model.set_adapter("vision") # if loaded, activate the vision adapter
|
model.set_adapter("vision") # if loaded, activate the vision adapter
|
||||||
print("\n--- IMAGE PROCESSING ---")
|
inputs = processor.apply_chat_template(
|
||||||
image_url = 'https://www.ilankelman.org/stopsigns/australia.jpg'
|
messages,
|
||||||
prompt = f'{user_prompt}<|image_1|>What is shown in this image?{prompt_suffix}{assistant_prompt}'
|
add_generation_prompt=True,
|
||||||
print(f'>>> Prompt\n{prompt}')
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
# Download and open image
|
return_tensors="pt",
|
||||||
image = Image.open(requests.get(image_url, stream=True).raw)
|
).to(device, torch.float16)
|
||||||
inputs = processor(text=prompt, images=image, return_tensors='pt').to(device)
|
|
||||||
|
|
||||||
# Generate response
|
# Generate response
|
||||||
generate_ids = model.generate(
|
generate_ids = model.generate(
|
||||||
@@ -80,19 +78,28 @@ response = processor.batch_decode(
|
|||||||
)[0]
|
)[0]
|
||||||
print(f'>>> Response\n{response}')
|
print(f'>>> Response\n{response}')
|
||||||
|
|
||||||
|
|
||||||
# Part 2: Audio Processing
|
# Part 2: Audio Processing
|
||||||
model.set_adapter("speech") # if loaded, activate the speech adapter
|
model.set_adapter("speech") # if loaded, activate the speech adapter
|
||||||
print("\n--- AUDIO PROCESSING ---")
|
|
||||||
audio_url = "https://upload.wikimedia.org/wikipedia/commons/b/b0/Barbara_Sahakian_BBC_Radio4_The_Life_Scientific_29_May_2012_b01j5j24.flac"
|
audio_url = "https://upload.wikimedia.org/wikipedia/commons/b/b0/Barbara_Sahakian_BBC_Radio4_The_Life_Scientific_29_May_2012_b01j5j24.flac"
|
||||||
speech_prompt = "Transcribe the audio to text, and then translate the audio to French. Use <sep> as a separator between the original transcript and the translation."
|
messages = [
|
||||||
prompt = f'{user_prompt}<|audio_1|>{speech_prompt}{prompt_suffix}{assistant_prompt}'
|
{
|
||||||
print(f'>>> Prompt\n{prompt}')
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "audio", "url": audio_url},
|
||||||
|
{"type": "text", "text": "Transcribe the audio to text, and then translate the audio to French. Use <sep> as a separator between the origina transcript and the translation."},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
# Downlowd and open audio file
|
inputs = processor.apply_chat_template(
|
||||||
audio, sample_rate = sf.read(io.BytesIO(urlopen(audio_url).read()))
|
messages,
|
||||||
|
add_generation_prompt=True,
|
||||||
# Process with the model
|
tokenize=True,
|
||||||
inputs = processor(text=prompt, audios=audio, sample_rate=sample_rate, return_tensors='pt').to(device)
|
return_dict=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
).to(device, torch.float16)
|
||||||
|
|
||||||
generate_ids = model.generate(
|
generate_ids = model.generate(
|
||||||
**inputs,
|
**inputs,
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from ...image_processing_utils_fast import (
|
|||||||
Unpack,
|
Unpack,
|
||||||
convert_to_rgb,
|
convert_to_rgb,
|
||||||
)
|
)
|
||||||
from ...image_utils import ImageInput, make_list_of_images, valid_images
|
from ...image_utils import ImageInput, make_flat_list_of_images, valid_images
|
||||||
from ...utils import TensorType, logging
|
from ...utils import TensorType, logging
|
||||||
|
|
||||||
|
|
||||||
@@ -175,7 +175,7 @@ class Phi4MultimodalImageProcessorFast(BaseImageProcessorFast):
|
|||||||
image_mean = image_mean if image_mean is not None else self.image_mean
|
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||||
image_std = image_std if image_std is not None else self.image_std
|
image_std = image_std if image_std is not None else self.image_std
|
||||||
|
|
||||||
images = make_list_of_images(images)
|
images = make_flat_list_of_images(images)
|
||||||
if not valid_images(images):
|
if not valid_images(images):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||||
|
|||||||
@@ -62,26 +62,22 @@ class Phi4MultimodalProcessor(ProcessorMixin):
|
|||||||
tokenizer_class = "GPT2TokenizerFast"
|
tokenizer_class = "GPT2TokenizerFast"
|
||||||
image_processor_class = "Phi4MultimodalImageProcessorFast"
|
image_processor_class = "Phi4MultimodalImageProcessorFast"
|
||||||
audio_processor_class = "Phi4MultimodalFeatureExtractor"
|
audio_processor_class = "Phi4MultimodalFeatureExtractor"
|
||||||
valid_kwargs = ["chat_template", "fake_image_token_pattern", "fake_audio_token_pattern"]
|
valid_kwargs = ["chat_template"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
image_processor,
|
image_processor,
|
||||||
audio_processor,
|
audio_processor,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
fake_image_token_pattern: str = r"<\|image_\d+\|>",
|
|
||||||
fake_audio_token_pattern: str = r"<\|audio_\d+\|>",
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(image_processor, audio_processor, tokenizer, **kwargs)
|
super().__init__(image_processor, audio_processor, tokenizer, **kwargs)
|
||||||
self.fake_image_token_pattern = fake_image_token_pattern
|
|
||||||
self.fake_audio_token_pattern = fake_audio_token_pattern
|
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
text: Union[TextInput, List[TextInput]],
|
text: Union[TextInput, List[TextInput]],
|
||||||
images: Optional[ImageInput] = None,
|
images: Optional[ImageInput] = None,
|
||||||
audios: Optional[AudioInput] = None,
|
audio: Optional[AudioInput] = None,
|
||||||
**kwargs: Unpack[ProcessingKwargs],
|
**kwargs: Unpack[ProcessingKwargs],
|
||||||
) -> BatchFeature:
|
) -> BatchFeature:
|
||||||
"""
|
"""
|
||||||
@@ -99,7 +95,7 @@ class Phi4MultimodalProcessor(ProcessorMixin):
|
|||||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||||
tensor. Both channels-first and channels-last formats are supported.
|
tensor. Both channels-first and channels-last formats are supported.
|
||||||
audios (`List[Union[np.ndarray, torch.Tensor]]`):
|
audio (`List[Union[np.ndarray, torch.Tensor]]`):
|
||||||
List of the audios to be prepared.
|
List of the audios to be prepared.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -120,7 +116,7 @@ class Phi4MultimodalProcessor(ProcessorMixin):
|
|||||||
text_kwargs = output_kwargs["text_kwargs"]
|
text_kwargs = output_kwargs["text_kwargs"]
|
||||||
|
|
||||||
image_inputs = self.image_processor(images, **image_kwargs) if images is not None else {}
|
image_inputs = self.image_processor(images, **image_kwargs) if images is not None else {}
|
||||||
audio_inputs = self.audio_processor(audios, **audio_kwargs) if audios is not None else {}
|
audio_inputs = self.audio_processor(audio, **audio_kwargs) if audio is not None else {}
|
||||||
|
|
||||||
# We pop here for images as we don't need it later
|
# We pop here for images as we don't need it later
|
||||||
num_img_tokens = image_inputs.pop("num_img_tokens", [])
|
num_img_tokens = image_inputs.pop("num_img_tokens", [])
|
||||||
@@ -134,25 +130,25 @@ class Phi4MultimodalProcessor(ProcessorMixin):
|
|||||||
|
|
||||||
image_token = self.tokenizer.image_token
|
image_token = self.tokenizer.image_token
|
||||||
audio_token = self.tokenizer.audio_token
|
audio_token = self.tokenizer.audio_token
|
||||||
processed_text = [re.sub(self.fake_image_token_pattern, image_token, t) for t in text]
|
|
||||||
processed_text = [re.sub(self.fake_audio_token_pattern, audio_token, t) for t in processed_text]
|
|
||||||
|
|
||||||
# Check that the number of special tokens is sound
|
# Check that the number of special tokens is sound
|
||||||
concatenated_prompt = "".join(processed_text)
|
concatenated_prompt = "".join(text)
|
||||||
if concatenated_prompt.count(self.tokenizer.image_token) != len(num_img_tokens):
|
if concatenated_prompt.count(image_token) != len(num_img_tokens):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"You should add as much image tokens `<|image_i|>` in your prompt as you pass `images` to the processor"
|
"You should add as much image tokens `<|image|>` in your prompt as you pass `images` to the processor. ",
|
||||||
|
f"Input contains {concatenated_prompt.count(image_token)} tokens != {len(num_img_tokens)} images",
|
||||||
)
|
)
|
||||||
if concatenated_prompt.count(self.tokenizer.audio_token) != len(audio_embed_sizes):
|
if concatenated_prompt.count(audio_token) != len(audio_embed_sizes):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"You should add as much audio tokens `<|audio_i|>` in your prompt as you pass `audios` to the processor"
|
"You should add as much audio tokens `<|audio|>` in your prompt as you pass `audios` to the processor. "
|
||||||
|
f"Input contains {concatenated_prompt.count(audio_token)} tokens != {len(audio_embed_sizes)} audios"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add appropriate number of image/audio tokens (note that the count of replacement is dynamic)
|
# Add appropriate number of image/audio tokens (note that the count of replacement is dynamic)
|
||||||
image_count_iter = iter(num_img_tokens)
|
image_count_iter = iter(num_img_tokens)
|
||||||
audio_count_iter = iter(audio_embed_sizes)
|
audio_count_iter = iter(audio_embed_sizes)
|
||||||
processed_text = [
|
processed_text = [
|
||||||
re.sub(re.escape(image_token), lambda _: image_token * next(image_count_iter), t) for t in processed_text
|
re.sub(re.escape(image_token), lambda _: image_token * next(image_count_iter), t) for t in text
|
||||||
]
|
]
|
||||||
processed_text = [
|
processed_text = [
|
||||||
re.sub(re.escape(audio_token), lambda _: audio_token * next(audio_count_iter), t) for t in processed_text
|
re.sub(re.escape(audio_token), lambda _: audio_token * next(audio_count_iter), t) for t in processed_text
|
||||||
|
|||||||
Reference in New Issue
Block a user