Transformers serve VLM (#39454)

* Add support for VLMs in Transformers Serve

* Raushan comments

* Update src/transformers/commands/serving.py

Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>

* Quick fix

* CPU -> Auto

* Update src/transformers/commands/serving.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Fixup

---------

Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
Lysandre Debut
2025-07-23 17:03:18 +02:00
committed by GitHub
parent ea56eb6bed
commit a0e5a7d34b
2 changed files with 268 additions and 75 deletions

View File

@@ -11,22 +11,34 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import base64
import copy import copy
import datetime
import enum
import functools import functools
import gc import gc
import io import io
import json import json
import re import re
import tempfile
import threading import threading
import time import time
from argparse import ArgumentParser, Namespace from argparse import ArgumentParser, Namespace
from dataclasses import dataclass, field from dataclasses import dataclass, field
from io import BytesIO
from threading import Thread from threading import Thread
from typing import Generator, Optional, Union from typing import Generator, Iterable, Optional, Union
from huggingface_hub import ModelInfo, model_info from huggingface_hub import model_info
from huggingface_hub.constants import HF_HUB_OFFLINE
from openai.types.chat import ChatCompletionMessageParam
from PIL import Image
import transformers
from transformers.models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
)
from transformers.utils.import_utils import ( from transformers.utils.import_utils import (
is_fastapi_available, is_fastapi_available,
is_librosa_available, is_librosa_available,
@@ -35,7 +47,13 @@ from transformers.utils.import_utils import (
is_uvicorn_available, is_uvicorn_available,
) )
from .. import LogitsProcessorList, PreTrainedTokenizerFast, ProcessorMixin, TextIteratorStreamer from .. import (
AutoConfig,
LogitsProcessorList,
PreTrainedTokenizerFast,
ProcessorMixin,
TextIteratorStreamer,
)
from ..generation.continuous_batching import ContinuousBatchingManager, RequestStatus from ..generation.continuous_batching import ContinuousBatchingManager, RequestStatus
from ..utils import is_torch_available, logging from ..utils import is_torch_available, logging
from . import BaseTransformersCLICommand from . import BaseTransformersCLICommand
@@ -45,8 +63,6 @@ if is_torch_available():
import torch import torch
from transformers import ( from transformers import (
AutoModelForCausalLM,
AutoModelForSpeechSeq2Seq,
AutoProcessor, AutoProcessor,
BitsAndBytesConfig, BitsAndBytesConfig,
GenerationConfig, GenerationConfig,
@@ -187,6 +203,13 @@ _TOOL_CALL_TOKENS = {
_MODELS_WITH_TOOL_SUPPORT = list(_TOOL_CALL_TOKENS.keys()) _MODELS_WITH_TOOL_SUPPORT = list(_TOOL_CALL_TOKENS.keys())
class Modality(enum.Enum):
    """Kind of model being served; selects how inbound chat messages are shaped for the processor."""

    # Text-only model: `get_model_modality` returns this for classes listed in the causal-LM mapping.
    LLM = "LLM"
    # Image + text model: `get_model_modality` returns this for classes listed in the image-text-to-text mapping.
    VLM = "VLM"
    # Presumably speech-to-text; not assigned by any code visible here -- TODO confirm.
    STT = "STT"
    # Presumably text-to-speech; not assigned by any code visible here -- TODO confirm.
    TTS = "TTS"
def serve_command_factory(args: Namespace): def serve_command_factory(args: Namespace):
""" """
Factory function used to instantiate serving server from provided command line arguments. Factory function used to instantiate serving server from provided command line arguments.
@@ -271,7 +294,7 @@ class ToolState:
class TimedModel: class TimedModel:
""" """
A class that holds a PreTrainedModel instance and its associated processor (tokenizer, audio processor, etc.). A class that holds a PreTrainedModel instance and its associated processor.
Automatically deletes the instances after a specified timeout. Automatically deletes the instances after a specified timeout.
""" """
@@ -325,7 +348,13 @@ class ServeArguments:
`transformers serve --help` `transformers serve --help`
""" """
device: str = field(default="cpu", metadata={"help": "Device to use for inference."}) device: str = field(
default="auto",
metadata={
"help": "Device to use for inference; will default to `auto` and"
"place the model on an accelerator if available."
},
)
torch_dtype: Optional[str] = field( torch_dtype: Optional[str] = field(
default="auto", default="auto",
metadata={ metadata={
@@ -438,7 +467,7 @@ class ServeCommand(BaseTransformersCLICommand):
# cache and avoid re-running prefill # cache and avoid re-running prefill
self.last_messages = None self.last_messages = None
self.last_kv_cache = None self.last_kv_cache = None
self.last_text_model = None self.last_model = None
def _validate_request( def _validate_request(
self, self,
@@ -632,27 +661,15 @@ class ServeCommand(BaseTransformersCLICommand):
output = self.generate_transcription(parsed_request) output = self.generate_transcription(parsed_request)
return StreamingResponse(output, media_type="text/event-stream") return StreamingResponse(output, media_type="text/event-stream")
@app.options("/v1/models")
@app.get("/v1/models") @app.get("/v1/models")
def get_all_models(): def get_all_models():
return JSONResponse( return JSONResponse({"object": "list", "data": self.get_gen_models()})
{
"object": "list",
"data": [
{
"id": model.id,
"object": "model",
"created": model.created_at.timestamp(),
"owned_by": model.author,
}
for model in self.get_text_gen_models()
],
}
)
uvicorn.run(app, host=self.args.host, port=self.args.port, log_level=self.args.log_level) uvicorn.run(app, host=self.args.host, port=self.args.port, log_level=self.args.log_level)
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def get_text_gen_models(self) -> list[ModelInfo]: def get_gen_models(self) -> list[dict[str, any]]:
""" """
This is by no means a limit to which models may be instantiated with `transformers serve`: any chat-based This is by no means a limit to which models may be instantiated with `transformers serve`: any chat-based
model working with generate can work. model working with generate can work.
@@ -660,16 +677,42 @@ class ServeCommand(BaseTransformersCLICommand):
This is a limited list of models to ensure we have a discoverable /v1/models endpoint for third-party This is a limited list of models to ensure we have a discoverable /v1/models endpoint for third-party
integrations. integrations.
""" """
models = [
"Menlo/Jan-nano",
"Menlo/Jan-nano-128k",
"Qwen/Qwen2.5-0.5B-Instruct",
"Qwen/Qwen2.5-3B-Instruct",
"Qwen/Qwen2.5-7B-Instruct",
"Qwen/Qwen2.5-14B-Instruct",
"meta-llama/Llama-3.1-8B-Instruct",
"meta-llama/Llama-3.2-1B-Instruct",
"meta-llama/Llama-3.3-70B-Instruct",
"HuggingFaceTB/SmolVLM-Instruct",
"ibm-granite/granite-vision-3.2-2b",
"Qwen/Qwen2.5-VL-7B-Instruct",
"OpenGVLab/InternVL3-1B",
]
if HF_HUB_OFFLINE:
return [ return [
model_info("Menlo/Jan-nano"), {
model_info("Menlo/Jan-nano-128k"), "id": model,
model_info("Qwen/Qwen2.5-0.5B-Instruct"), "object": "model",
model_info("Qwen/Qwen2.5-3B-Instruct"), "created": datetime.datetime.now().timestamp(),
model_info("Qwen/Qwen2.5-7B-Instruct"), "owned_by": model.split("/")[0],
model_info("Qwen/Qwen2.5-14B-Instruct"), }
model_info("meta-llama/Llama-3.1-8B-Instruct"), for model in models
model_info("meta-llama/Llama-3.2-1B-Instruct"), ]
model_info("meta-llama/Llama-3.3-70B-Instruct"), else:
model_infos = [model_info(model) for model in models]
return [
{
"id": model.id,
"object": "model",
"created": model.created_at.timestamp(),
"owned_by": model.author,
}
for model in model_infos
] ]
def continuous_batching_chat_completion(self, req: dict) -> Generator[str, None, None]: def continuous_batching_chat_completion(self, req: dict) -> Generator[str, None, None]:
@@ -684,14 +727,16 @@ class ServeCommand(BaseTransformersCLICommand):
""" """
model_id_and_revision = self.process_model_name(req["model"]) model_id_and_revision = self.process_model_name(req["model"])
must_discard_cache = model_id_and_revision != self.last_text_model must_discard_cache = model_id_and_revision != self.last_model
self.last_text_model = model_id_and_revision self.last_model = model_id_and_revision
if must_discard_cache: if must_discard_cache:
# When switching models, terminate a continuous batching manager if it is running. # When switching models, terminate a continuous batching manager if it is running.
if self.running_continuous_batching_manager is not None: if self.running_continuous_batching_manager is not None:
self.running_continuous_batching_manager.stop(block=True, timeout=2) self.running_continuous_batching_manager.stop(block=True, timeout=2)
self.running_continuous_batching_manager = None self.running_continuous_batching_manager = None
model, tokenizer = self.load_text_model_and_tokenizer(model_id_and_revision) model, processor = self.load_model_and_processor(model_id_and_revision)
tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
generation_config = create_generation_config_from_req( generation_config = create_generation_config_from_req(
req, req,
@@ -717,7 +762,7 @@ class ServeCommand(BaseTransformersCLICommand):
self.running_continuous_batching_manager.start() self.running_continuous_batching_manager.start()
# TODO (Joao, Lysandre): this should also work with tool support # TODO (Joao, Lysandre): this should also work with tool support
inputs = tokenizer.apply_chat_template(req["messages"], return_tensors="pt", add_generation_prompt=True).to( inputs = processor.apply_chat_template(req["messages"], return_tensors="pt", add_generation_prompt=True).to(
model.device model.device
) )
@@ -759,6 +804,50 @@ class ServeCommand(BaseTransformersCLICommand):
return stream_chat_completion(inputs[0]) return stream_chat_completion(inputs[0])
@staticmethod
def get_model_modality(model: PreTrainedModel) -> Modality:
    """
    Infers the modality of a loaded model from the name of its class.

    Args:
        model (`PreTrainedModel`):
            The instantiated model to classify.

    Returns:
        `Modality`: `Modality.VLM` when the class name is listed in `MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES`,
        `Modality.LLM` when it is listed in `MODEL_FOR_CAUSAL_LM_MAPPING_NAMES`.

    Raises:
        `ValueError`: If the class name is found in neither mapping.
    """
    name = model.__class__.__name__

    # The image-text-to-text mapping is consulted first, so a class present in both mappings counts as a VLM.
    if name in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values():
        return Modality.VLM
    if name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
        return Modality.LLM

    raise ValueError(f"Unknown modality: {name}")
@staticmethod
def get_processor_inputs_from_inbound_messages(messages, modality: Modality):
    """
    Converts OpenAI-style chat messages into the message format handed to `processor.apply_chat_template`.

    Args:
        messages:
            Iterable of message dictionaries holding a "role" key and a "content" key. "content" is either a
            string or a list of content parts such as `{"type": "text", "text": ...}` and
            `{"type": "image_url", "image_url": {"url": ...}}`.
        modality (`Modality`):
            Modality of the target model. Only `Modality.LLM` and `Modality.VLM` are converted; for any other
            value every message keeps an empty "content" list.

    Returns:
        `list[dict]`: One dictionary per inbound message. For LLMs "content" is a single string; for VLMs it is
        a list of `{"type": "text", ...}` and `{"type": "image", "url": ...}` parts.
    """
    processor_inputs = []

    for message in messages:
        content = message["content"]
        parsed_message = {"role": message["role"], "content": []}

        if modality == Modality.LLM:
            # LLM chat templates expect "content" to be a single string.
            if isinstance(content, str):
                parsed_message["content"] = content
            elif isinstance(content, list):
                # OpenAI clients may send a list of parts even for text-only chats; keep the text parts only.
                parsed_message["content"] = " ".join(
                    part["text"] for part in content if part.get("type") == "text"
                )
            else:
                # Single mapping carrying a "text" key (the only non-string shape previously supported).
                parsed_message["content"] = content["text"]

        elif modality == Modality.VLM:
            # VLM chat templates expect "content" to be a list of typed parts.
            if isinstance(content, str):
                parsed_message["content"].append({"type": "text", "text": content})
            else:
                for part in content:
                    if part["type"] == "text":
                        parsed_message["content"].append(part)
                    elif part["type"] == "image_url":
                        url = part["image_url"]["url"]
                        if url.startswith(("http://", "https://")):
                            # Remote image: nothing to decode, forward the URL as-is.
                            # NOTE(review): assumes the processor can fetch remote URLs -- confirm.
                            parsed_message["content"].append({"type": "image", "url": url})
                            continue

                        # Base64 payload, with or without a `data:image/...;base64,` prefix.
                        image_data = re.sub(r"^data:image/.+;base64,", "", url)
                        image = Image.open(BytesIO(base64.b64decode(image_data)))
                        # `delete=False` keeps the file on disk after the handle is closed, since only its path
                        # is returned. Nothing in this function removes the file afterwards.
                        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as file:
                            image.save(file, format="PNG")
                        parsed_message["content"].append({"type": "image", "url": file.name})

        processor_inputs.append(parsed_message)
    return processor_inputs
def generate_chat_completion(self, req: dict) -> Generator[str, None, None]: def generate_chat_completion(self, req: dict) -> Generator[str, None, None]:
""" """
Generates an OpenAI Chat Completion using `generate`. Generates an OpenAI Chat Completion using `generate`.
@@ -769,15 +858,24 @@ class ServeCommand(BaseTransformersCLICommand):
Returns: Returns:
`Generator[str, None, None]`: A generator that yields the OpenAI Chat Completion chunks. `Generator[str, None, None]`: A generator that yields the OpenAI Chat Completion chunks.
""" """
if self.args.force_model is not None:
req["model"] = self.args.force_model
messages: Iterable[ChatCompletionMessageParam] = req["messages"]
# HACK for tiny-agents: it sends a request after the assistant message (???). Let's assume we can't have a # HACK for tiny-agents: it sends a request after the assistant message (???). Let's assume we can't have a
# request whose last message is from the assistant. # request whose last message is from the assistant.
if req["messages"][-1]["role"] == "assistant": if messages[-1]["role"] == "assistant":
return return
model_id_and_revision = self.process_model_name(req["model"]) model_id_and_revision = self.process_model_name(req["model"])
must_discard_cache = model_id_and_revision != self.last_text_model must_discard_cache = model_id_and_revision != self.last_model
self.last_text_model = model_id_and_revision
model, tokenizer = self.load_text_model_and_tokenizer(model_id_and_revision) self.last_model = model_id_and_revision
model, processor = self.load_model_and_processor(model_id_and_revision)
modality = self.get_model_modality(model)
processor_inputs = self.get_processor_inputs_from_inbound_messages(messages, modality)
# ====== TOOL PREPROCESSING LOGIC ====== # ====== TOOL PREPROCESSING LOGIC ======
tool_model_family = None tool_model_family = None
@@ -790,16 +888,18 @@ class ServeCommand(BaseTransformersCLICommand):
# 2. force generation to pick from that tool's arguments # 2. force generation to pick from that tool's arguments
# ====== END OF TOOL PREPROCESSING LOGIC ====== # ====== END OF TOOL PREPROCESSING LOGIC ======
if tool_model_family is not None: inputs = processor.apply_chat_template(
text = tokenizer.apply_chat_template( processor_inputs,
req["messages"], add_generation_prompt=True, tokenize=False, tools=req.get("tools") add_generation_prompt=True,
tools=req.get("tools", None),
return_tensors="pt",
return_dict=True,
tokenize=True,
) )
else: inputs = inputs.to(model.device)
text = tokenizer.apply_chat_template(req["messages"], add_generation_prompt=True, tokenize=False)
inputs = tokenizer(text, return_tensors="pt").to(model.device)["input_ids"]
request_id = req.get("request_id", "req_0") request_id = req.get("request_id", "req_0")
generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True) generation_streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)
generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config) generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config)
last_kv_cache = None last_kv_cache = None
@@ -807,8 +907,7 @@ class ServeCommand(BaseTransformersCLICommand):
last_kv_cache = self.last_kv_cache last_kv_cache = self.last_kv_cache
generation_kwargs = { generation_kwargs = {
"inputs": inputs, **inputs,
"attention_mask": torch.ones_like(inputs),
"streamer": generation_streamer, "streamer": generation_streamer,
"generation_config": generation_config, "generation_config": generation_config,
"return_dict_in_generate": True, "return_dict_in_generate": True,
@@ -929,15 +1028,14 @@ class ServeCommand(BaseTransformersCLICommand):
""" """
# TODO -- Implement non-streaming mode # TODO -- Implement non-streaming mode
model_id_and_revision = self.process_model_name(req["model"]) model_id_and_revision = self.process_model_name(req["model"])
must_discard_cache = model_id_and_revision != self.last_text_model must_discard_cache = model_id_and_revision != self.last_model
self.last_text_model = model_id_and_revision self.last_model = model_id_and_revision
model, tokenizer = self.load_text_model_and_tokenizer(model_id_and_revision) model, processor = self.load_model_and_processor(model_id_and_revision)
text = tokenizer.apply_chat_template(req["input"], add_generation_prompt=True, tokenize=False) inputs = processor.apply_chat_template(req["input"], add_generation_prompt=True).to(model.device)
inputs = tokenizer(text, return_tensors="pt").to(model.device)["input_ids"]
request_id = req.get("previous_response_id", "req_0") request_id = req.get("previous_response_id", "req_0")
generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True) generation_streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)
generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config) generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config)
last_kv_cache = None last_kv_cache = None
@@ -1282,9 +1380,7 @@ class ServeCommand(BaseTransformersCLICommand):
return model_id return model_id
return f"{model_id}@main" return f"{model_id}@main"
def _load_model_and_data_processor( def _load_model_and_data_processor(self, model_id_and_revision: str):
self, model_id_and_revision: str, model_cls: type[PreTrainedModel]
) -> tuple[PreTrainedModel, Union[ProcessorMixin, PreTrainedTokenizerFast]]:
""" """
Generic method to load a model and a data processor from a model ID and revision, making use of the serve CLI Generic method to load a model and a data processor from a model ID and revision, making use of the serve CLI
arguments. arguments.
@@ -1325,7 +1421,9 @@ class ServeCommand(BaseTransformersCLICommand):
"trust_remote_code": args.trust_remote_code, "trust_remote_code": args.trust_remote_code,
} }
model = model_cls.from_pretrained(model_id, **model_kwargs) config = AutoConfig.from_pretrained(model_id, **model_kwargs)
architecture = getattr(transformers, config.architectures[0])
model = architecture.from_pretrained(model_id, **model_kwargs)
if getattr(model, "hf_device_map", None) is None: if getattr(model, "hf_device_map", None) is None:
model = model.to(args.device) model = model.to(args.device)
@@ -1342,32 +1440,30 @@ class ServeCommand(BaseTransformersCLICommand):
logger.info(f"Loaded model {model_id_and_revision}") logger.info(f"Loaded model {model_id_and_revision}")
return model, data_processor return model, data_processor
def load_text_model_and_tokenizer( def load_model_and_processor(self, model_id_and_revision: str) -> tuple[PreTrainedModel, PreTrainedTokenizerFast]:
self, model_id_and_revision: str
) -> tuple[PreTrainedModel, PreTrainedTokenizerFast]:
""" """
Loads the text model and tokenizer from the given model ID and revision into the ServeCommand instance. Loads the text model and processor from the given model ID and revision into the ServeCommand instance.
Args: Args:
model_id_and_revision (`str`): model_id_and_revision (`str`):
The model ID and revision to load. The model ID and revision to load.
Returns: Returns:
`tuple[PreTrainedModel, PreTrainedTokenizerFast]`: The loaded text model and tokenizer. `tuple[PreTrainedModel, PreTrainedTokenizerFast]`: The loaded text model and processor.
""" """
if model_id_and_revision not in self.loaded_models or self.loaded_models[model_id_and_revision].is_deleted(): if model_id_and_revision not in self.loaded_models or self.loaded_models[model_id_and_revision].is_deleted():
model, tokenizer = self._load_model_and_data_processor(model_id_and_revision, AutoModelForCausalLM) model, processor = self._load_model_and_data_processor(model_id_and_revision)
self.loaded_models[model_id_and_revision] = TimedModel( self.loaded_models[model_id_and_revision] = TimedModel(
model, model,
timeout_seconds=self.args.model_timeout, timeout_seconds=self.args.model_timeout,
processor=tokenizer, processor=processor,
) )
else: else:
self.loaded_models[model_id_and_revision].reset_timer() self.loaded_models[model_id_and_revision].reset_timer()
model = self.loaded_models[model_id_and_revision].model model = self.loaded_models[model_id_and_revision].model
tokenizer = self.loaded_models[model_id_and_revision].processor processor = self.loaded_models[model_id_and_revision].processor
return model, tokenizer return model, processor
def load_audio_model_and_processor(self, model_id_and_revision: str) -> tuple[PreTrainedModel, ProcessorMixin]: def load_audio_model_and_processor(self, model_id_and_revision: str) -> tuple[PreTrainedModel, ProcessorMixin]:
""" """
@@ -1381,9 +1477,7 @@ class ServeCommand(BaseTransformersCLICommand):
`tuple[PreTrainedModel, ProcessorMixin]`: The loaded audio model and processor. `tuple[PreTrainedModel, ProcessorMixin]`: The loaded audio model and processor.
""" """
if model_id_and_revision not in self.loaded_models or self.loaded_models[model_id_and_revision].is_deleted(): if model_id_and_revision not in self.loaded_models or self.loaded_models[model_id_and_revision].is_deleted():
audio_model, audio_processor = self._load_model_and_data_processor( audio_model, audio_processor = self._load_model_and_data_processor(model_id_and_revision)
model_id_and_revision, AutoModelForSpeechSeq2Seq
)
self.loaded_models[model_id_and_revision] = TimedModel( self.loaded_models[model_id_and_revision] = TimedModel(
audio_model, audio_model,
timeout_seconds=self.args.model_timeout, timeout_seconds=self.args.model_timeout,

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio import asyncio
import os
import time import time
import unittest import unittest
from threading import Thread from threading import Thread
@@ -23,7 +24,7 @@ from parameterized import parameterized
import transformers.commands.transformers_cli as cli import transformers.commands.transformers_cli as cli
from transformers import GenerationConfig from transformers import GenerationConfig
from transformers.commands.serving import ServeArguments, ServeCommand from transformers.commands.serving import Modality, ServeArguments, ServeCommand
from transformers.testing_utils import CaptureStd, require_openai, slow from transformers.testing_utils import CaptureStd, require_openai, slow
from transformers.utils.import_utils import is_openai_available from transformers.utils.import_utils import is_openai_available
@@ -258,6 +259,104 @@ class ServeCompletionsMixin:
# TODO: speed-based test to confirm that KV cache is working across requests # TODO: speed-based test to confirm that KV cache is working across requests
class ServeCompletionsGenerateMockTests(unittest.TestCase):
    """Unit tests for `ServeCommand.get_processor_inputs_from_inbound_messages`; no server or model is started."""

    def test_processor_inputs_from_inbound_messages_llm(self):
        # For LLMs, string contents must come back unchanged.
        modality = Modality.LLM
        messages = expected_outputs = [
            {"role": "user", "content": "How are you doing?"},
            {"role": "assistant", "content": "I'm doing great, thank you for asking! How can I assist you today?"},
            {"role": "user", "content": "Can you help me write tests?"},
        ]
        outputs = ServeCommand.get_processor_inputs_from_inbound_messages(messages, modality)
        self.assertListEqual(expected_outputs, outputs)

    def test_processor_inputs_from_inbound_messages_vlm_text_only(self):
        # For VLMs, string contents must be wrapped into a one-element list of text parts.
        modality = Modality.VLM
        messages = [
            {"role": "user", "content": "How are you doing?"},
            {"role": "assistant", "content": "I'm doing great, thank you for asking! How can I assist you today?"},
            {"role": "user", "content": "Can you help me write tests?"},
        ]

        expected_outputs = [
            {"role": "user", "content": [{"type": "text", "text": "How are you doing?"}]},
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": "I'm doing great, thank you for asking! How can I assist you today?"}
                ],
            },
            {"role": "user", "content": [{"type": "text", "text": "Can you help me write tests?"}]},
        ]

        outputs = ServeCommand.get_processor_inputs_from_inbound_messages(messages, modality)
        self.assertListEqual(expected_outputs, outputs)

    def test_processor_inputs_from_inbound_messages_vlm_text_and_image_in_base_64(self):
        # A base64 `image_url` part must be written to disk and replaced by an `image` part pointing at the file.
        modality = Modality.VLM
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "How many pixels are in the image?"},
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAASABIAAD/4QBARXhpZgAATU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAAqACAAQAAAABAAAABaADAAQAAAABAAAABQAAAAD/7QA4UGhvdG9zaG9wIDMuMAA4QklNBAQAAAAAAAA4QklNBCUAAAAAABDUHYzZjwCyBOmACZjs+EJ+/8AAEQgABQAFAwEiAAIRAQMRAf/EAB8AAAEFAQEBAQEBAAAAAAAAAAABAgMEBQYHCAkKC//EALUQAAIBAwMCBAMFBQQEAAABfQECAwAEEQUSITFBBhNRYQcicRQygZGhCCNCscEVUtHwJDNicoIJChYXGBkaJSYnKCkqNDU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6g4SFhoeIiYqSk5SVlpeYmZqio6Slpqeoqaqys7S1tre4ubrCw8TFxsfIycrS09TV1tfY2drh4uPk5ebn6Onq8fLz9PX29/j5+v/EAB8BAAMBAQEBAQEBAQEAAAAAAAABAgMEBQYHCAkKC//EALURAAIBAgQEAwQHBQQEAAECdwABAgMRBAUhMQYSQVEHYXETIjKBCBRCkaGxwQkjM1LwFWJy0QoWJDThJfEXGBkaJicoKSo1Njc4OTpDREVGR0hJSlNUVVZXWFlaY2RlZmdoaWpzdHV2d3h5eoKDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uLj5OXm5+jp6vLz9PX29/j5+v/bAEMAAQEBAQEBAgEBAgICAgICAwICAgIDBAMDAwMDBAUEBAQEBAQFBQUFBQUFBQYGBgYGBgcHBwcHCAgICAgICAgICP/bAEMBAQEBAgICAwICAwgFBAUICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICP/dAAQAAf/aAAwDAQACEQMRAD8A/v4ooooA/9k="
                        },
                    },
                ],
            },
            {
                "role": "assistant",
                "content": "The number of pixels in the image cannot be determined from the provided information.",
            },
            {"role": "user", "content": "Alright"},
        ]

        # The image "url" below is only a placeholder: the real temporary path differs on every run, so it is
        # never compared -- only the existence of the produced file is checked.
        expected_outputs = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "How many pixels are in the image?"},
                    {"type": "image", "url": "/var/folders/4v/64sxdhsd3gz3r8vhhnyc0mqw0000gn/T/tmp50oyghk6.png"},
                ],
            },
            {
                "role": "assistant",
                "content": [
                    {
                        "type": "text",
                        "text": "The number of pixels in the image cannot be determined from the provided information.",
                    }
                ],
            },
            {"role": "user", "content": [{"type": "text", "text": "Alright"}]},
        ]

        outputs = ServeCommand.get_processor_inputs_from_inbound_messages(messages, modality)

        # `zip` stops at the shorter sequence, so lengths are checked explicitly to catch dropped messages.
        self.assertEqual(len(expected_outputs), len(outputs))

        for expected_output, output in zip(expected_outputs, outputs):
            self.assertEqual(expected_output["role"], output["role"])

            expected_output_content = expected_output["content"]
            output_content = output["content"]

            self.assertIsInstance(expected_output_content, list, "VLMs should only receive content as lists.")
            self.assertIsInstance(output_content, list, "VLMs should only receive content as lists.")
            self.assertEqual(len(expected_output_content), len(output_content))

            for expected_output_content_item, output_content_item in zip(expected_output_content, output_content):
                self.assertIn("type", expected_output_content_item)
                self.assertIn("type", output_content_item)
                self.assertEqual(expected_output_content_item["type"], output_content_item["type"])

                if expected_output_content_item["type"] == "text":
                    self.assertEqual(expected_output_content_item["text"], output_content_item["text"])

                if expected_output_content_item["type"] == "image":
                    self.assertTrue(os.path.exists(output_content_item["url"]))
                    # The call under test leaves the decoded image on disk; remove it once the test is done.
                    self.addCleanup(os.remove, output_content_item["url"])
@slow # server startup time is slow on our push CI @slow # server startup time is slow on our push CI
@require_openai @require_openai
class ServeCompletionsGenerateIntegrationTest(ServeCompletionsMixin, unittest.TestCase): class ServeCompletionsGenerateIntegrationTest(ServeCompletionsMixin, unittest.TestCase):