diff --git a/src/transformers/commands/serving.py b/src/transformers/commands/serving.py index 3209f3c8ae..ccf35dbd3b 100644 --- a/src/transformers/commands/serving.py +++ b/src/transformers/commands/serving.py @@ -11,22 +11,34 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import base64 import copy +import datetime +import enum import functools import gc import io import json import re +import tempfile import threading import time from argparse import ArgumentParser, Namespace from dataclasses import dataclass, field +from io import BytesIO from threading import Thread -from typing import Generator, Optional, Union +from typing import Generator, Iterable, Optional, Union -from huggingface_hub import ModelInfo, model_info +from huggingface_hub import model_info +from huggingface_hub.constants import HF_HUB_OFFLINE +from openai.types.chat import ChatCompletionMessageParam +from PIL import Image +import transformers +from transformers.models.auto.modeling_auto import ( + MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, + MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, +) from transformers.utils.import_utils import ( is_fastapi_available, is_librosa_available, @@ -35,7 +47,13 @@ from transformers.utils.import_utils import ( is_uvicorn_available, ) -from .. import LogitsProcessorList, PreTrainedTokenizerFast, ProcessorMixin, TextIteratorStreamer +from .. import ( + AutoConfig, + LogitsProcessorList, + PreTrainedTokenizerFast, + ProcessorMixin, + TextIteratorStreamer, +) from ..generation.continuous_batching import ContinuousBatchingManager, RequestStatus from ..utils import is_torch_available, logging from . 
import BaseTransformersCLICommand @@ -45,8 +63,6 @@ if is_torch_available(): import torch from transformers import ( - AutoModelForCausalLM, - AutoModelForSpeechSeq2Seq, AutoProcessor, BitsAndBytesConfig, GenerationConfig, @@ -187,6 +203,13 @@ _TOOL_CALL_TOKENS = { _MODELS_WITH_TOOL_SUPPORT = list(_TOOL_CALL_TOKENS.keys()) +class Modality(enum.Enum): + LLM = "LLM" + VLM = "VLM" + STT = "STT" + TTS = "TTS" + + def serve_command_factory(args: Namespace): """ Factory function used to instantiate serving server from provided command line arguments. @@ -271,7 +294,7 @@ class ToolState: class TimedModel: """ - A class that holds a PreTrainedModel instance and its associated processor (tokenizer, audio processor, etc.). + A class that holds a PreTrainedModel instance and its associated processor. Automatically deletes the instances after a specified timeout. """ @@ -325,7 +348,13 @@ class ServeArguments: `transformers serve --help` """ - device: str = field(default="cpu", metadata={"help": "Device to use for inference."}) + device: str = field( + default="auto", + metadata={ + "help": "Device to use for inference; will default to `auto` and " + "place the model on an accelerator if available." 
+ }, + ) torch_dtype: Optional[str] = field( default="auto", metadata={ @@ -438,7 +467,7 @@ class ServeCommand(BaseTransformersCLICommand): # cache and avoid re-running prefil self.last_messages = None self.last_kv_cache = None - self.last_text_model = None + self.last_model = None def _validate_request( self, @@ -632,27 +661,15 @@ class ServeCommand(BaseTransformersCLICommand): output = self.generate_transcription(parsed_request) return StreamingResponse(output, media_type="text/event-stream") + @app.options("/v1/models") @app.get("/v1/models") def get_all_models(): - return JSONResponse( - { - "object": "list", - "data": [ - { - "id": model.id, - "object": "model", - "created": model.created_at.timestamp(), - "owned_by": model.author, - } - for model in self.get_text_gen_models() - ], - } - ) + return JSONResponse({"object": "list", "data": self.get_gen_models()}) uvicorn.run(app, host=self.args.host, port=self.args.port, log_level=self.args.log_level) @functools.lru_cache(maxsize=None) - def get_text_gen_models(self) -> list[ModelInfo]: + def get_gen_models(self) -> list[dict[str, any]]: """ This is by no means a limit to which models may be instantiated with `transformers serve`: any chat-based model working with generate can work. @@ -660,18 +677,44 @@ class ServeCommand(BaseTransformersCLICommand): This is a limited list of models to ensure we have a discoverable /v1/models endpoint for third-party integrations. 
""" - return [ - model_info("Menlo/Jan-nano"), - model_info("Menlo/Jan-nano-128k"), - model_info("Qwen/Qwen2.5-0.5B-Instruct"), - model_info("Qwen/Qwen2.5-3B-Instruct"), - model_info("Qwen/Qwen2.5-7B-Instruct"), - model_info("Qwen/Qwen2.5-14B-Instruct"), - model_info("meta-llama/Llama-3.1-8B-Instruct"), - model_info("meta-llama/Llama-3.2-1B-Instruct"), - model_info("meta-llama/Llama-3.3-70B-Instruct"), + models = [ + "Menlo/Jan-nano", + "Menlo/Jan-nano-128k", + "Qwen/Qwen2.5-0.5B-Instruct", + "Qwen/Qwen2.5-3B-Instruct", + "Qwen/Qwen2.5-7B-Instruct", + "Qwen/Qwen2.5-14B-Instruct", + "meta-llama/Llama-3.1-8B-Instruct", + "meta-llama/Llama-3.2-1B-Instruct", + "meta-llama/Llama-3.3-70B-Instruct", + "HuggingFaceTB/SmolVLM-Instruct", + "ibm-granite/granite-vision-3.2-2b", + "Qwen/Qwen2.5-VL-7B-Instruct", + "OpenGVLab/InternVL3-1B", ] + if HF_HUB_OFFLINE: + return [ + { + "id": model, + "object": "model", + "created": datetime.datetime.now().timestamp(), + "owned_by": model.split("/")[0], + } + for model in models + ] + else: + model_infos = [model_info(model) for model in models] + return [ + { + "id": model.id, + "object": "model", + "created": model.created_at.timestamp(), + "owned_by": model.author, + } + for model in model_infos + ] + def continuous_batching_chat_completion(self, req: dict) -> Generator[str, None, None]: """ Generates an OpenAI Chat Completion using continuous batching. @@ -684,14 +727,16 @@ class ServeCommand(BaseTransformersCLICommand): """ model_id_and_revision = self.process_model_name(req["model"]) - must_discard_cache = model_id_and_revision != self.last_text_model - self.last_text_model = model_id_and_revision + must_discard_cache = model_id_and_revision != self.last_model + self.last_model = model_id_and_revision if must_discard_cache: # When switching models, terminate a continuous batching manager if it is running. 
if self.running_continuous_batching_manager is not None: self.running_continuous_batching_manager.stop(block=True, timeout=2) self.running_continuous_batching_manager = None - model, tokenizer = self.load_text_model_and_tokenizer(model_id_and_revision) + model, processor = self.load_model_and_processor(model_id_and_revision) + + tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor generation_config = create_generation_config_from_req( req, @@ -717,7 +762,7 @@ class ServeCommand(BaseTransformersCLICommand): self.running_continuous_batching_manager.start() # TODO (Joao, Lysandre): this should also work with tool support - inputs = tokenizer.apply_chat_template(req["messages"], return_tensors="pt", add_generation_prompt=True).to( + inputs = processor.apply_chat_template(req["messages"], return_tensors="pt", add_generation_prompt=True).to( model.device ) @@ -759,6 +804,50 @@ class ServeCommand(BaseTransformersCLICommand): return stream_chat_completion(inputs[0]) + @staticmethod + def get_model_modality(model: PreTrainedModel) -> Modality: + model_classname = model.__class__.__name__ + if model_classname in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values(): + modality = Modality.VLM + elif model_classname in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): + modality = Modality.LLM + else: + raise ValueError(f"Unknown modality: {model_classname}") + + return modality + + @staticmethod + def get_processor_inputs_from_inbound_messages(messages, modality: Modality): + processor_inputs = [] + + for message in messages: + parsed_message = {"role": message["role"], "content": []} + + if modality == Modality.LLM: + # If we're working with LLMs, then "content" is a single string. 
+ content = message["content"] if isinstance(message["content"], str) else message["content"]["text"] + parsed_message["content"] = content + + elif modality == Modality.VLM: + # If we're working with VLMs, then "content" is a list of dictionaries, each containing a "type" key indicating + # which other key will be present and the type of the value of said key. + if isinstance(message["content"], str): + parsed_message["content"].append({"type": "text", "text": message["content"]}) + else: + for content in message["content"]: + if content["type"] == "text": + parsed_message["content"].append(content) + elif content["type"] == "image_url": + image_data = re.sub("^data:image/.+;base64,", "", content["image_url"]["url"]) + image = Image.open(BytesIO(base64.b64decode(image_data))) + + file = tempfile.NamedTemporaryFile(suffix=".png", delete=False) + image.save(file.name) + + parsed_message["content"].append({"type": "image", "url": file.name}) + processor_inputs.append(parsed_message) + return processor_inputs + def generate_chat_completion(self, req: dict) -> Generator[str, None, None]: """ Generates an OpenAI Chat Completion using `generate`. @@ -769,15 +858,24 @@ class ServeCommand(BaseTransformersCLICommand): Returns: `Generator[str, None, None]`: A generator that yields the OpenAI Chat Completion chunks. """ + if self.args.force_model is not None: + req["model"] = self.args.force_model + + messages: Iterable[ChatCompletionMessageParam] = req["messages"] + # HACK for tiny-agents: it sends a request after the assistant message (???). Let's assume we can't have a # request whose last message is from the assistant. 
- if req["messages"][-1]["role"] == "assistant": + if messages[-1]["role"] == "assistant": return model_id_and_revision = self.process_model_name(req["model"]) - must_discard_cache = model_id_and_revision != self.last_text_model - self.last_text_model = model_id_and_revision - model, tokenizer = self.load_text_model_and_tokenizer(model_id_and_revision) + must_discard_cache = model_id_and_revision != self.last_model + + self.last_model = model_id_and_revision + model, processor = self.load_model_and_processor(model_id_and_revision) + + modality = self.get_model_modality(model) + processor_inputs = self.get_processor_inputs_from_inbound_messages(messages, modality) # ====== TOOL PREPROCESSING LOGIC ====== tool_model_family = None @@ -790,16 +888,18 @@ class ServeCommand(BaseTransformersCLICommand): # 2. force generation to pick from that tool's arguments # ====== END OF TOOL PREPROCESSING LOGIC ====== - if tool_model_family is not None: - text = tokenizer.apply_chat_template( - req["messages"], add_generation_prompt=True, tokenize=False, tools=req.get("tools") - ) - else: - text = tokenizer.apply_chat_template(req["messages"], add_generation_prompt=True, tokenize=False) - inputs = tokenizer(text, return_tensors="pt").to(model.device)["input_ids"] + inputs = processor.apply_chat_template( + processor_inputs, + add_generation_prompt=True, + tools=req.get("tools", None), + return_tensors="pt", + return_dict=True, + tokenize=True, + ) + inputs = inputs.to(model.device) request_id = req.get("request_id", "req_0") - generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True) + generation_streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True) generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config) last_kv_cache = None @@ -807,8 +907,7 @@ class ServeCommand(BaseTransformersCLICommand): last_kv_cache = self.last_kv_cache generation_kwargs = { - "inputs": 
inputs, - "attention_mask": torch.ones_like(inputs), + **inputs, "streamer": generation_streamer, "generation_config": generation_config, "return_dict_in_generate": True, @@ -929,15 +1028,14 @@ class ServeCommand(BaseTransformersCLICommand): """ # TODO -- Implement non-streaming mode model_id_and_revision = self.process_model_name(req["model"]) - must_discard_cache = model_id_and_revision != self.last_text_model - self.last_text_model = model_id_and_revision - model, tokenizer = self.load_text_model_and_tokenizer(model_id_and_revision) + must_discard_cache = model_id_and_revision != self.last_model + self.last_model = model_id_and_revision + model, processor = self.load_model_and_processor(model_id_and_revision) - text = tokenizer.apply_chat_template(req["input"], add_generation_prompt=True, tokenize=False) - inputs = tokenizer(text, return_tensors="pt").to(model.device)["input_ids"] + inputs = processor.apply_chat_template(req["input"], add_generation_prompt=True).to(model.device) request_id = req.get("previous_response_id", "req_0") - generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True) + generation_streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True) generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config) last_kv_cache = None @@ -1282,9 +1380,7 @@ class ServeCommand(BaseTransformersCLICommand): return model_id return f"{model_id}@main" - def _load_model_and_data_processor( - self, model_id_and_revision: str, model_cls: type[PreTrainedModel] - ) -> tuple[PreTrainedModel, Union[ProcessorMixin, PreTrainedTokenizerFast]]: + def _load_model_and_data_processor(self, model_id_and_revision: str): """ Generic method to load a model and a data processor from a model ID and revision, making use of the serve CLI arguments. 
@@ -1325,7 +1421,9 @@ class ServeCommand(BaseTransformersCLICommand): "trust_remote_code": args.trust_remote_code, } - model = model_cls.from_pretrained(model_id, **model_kwargs) + config = AutoConfig.from_pretrained(model_id, **model_kwargs) + architecture = getattr(transformers, config.architectures[0]) + model = architecture.from_pretrained(model_id, **model_kwargs) if getattr(model, "hf_device_map", None) is None: model = model.to(args.device) @@ -1342,32 +1440,30 @@ class ServeCommand(BaseTransformersCLICommand): logger.info(f"Loaded model {model_id_and_revision}") return model, data_processor - def load_text_model_and_tokenizer( - self, model_id_and_revision: str - ) -> tuple[PreTrainedModel, PreTrainedTokenizerFast]: + def load_model_and_processor(self, model_id_and_revision: str) -> tuple[PreTrainedModel, PreTrainedTokenizerFast]: """ - Loads the text model and tokenizer from the given model ID and revision into the ServeCommand instance. + Loads the text model and processor from the given model ID and revision into the ServeCommand instance. Args: model_id_and_revision (`str`): The model ID and revision to load. Returns: - `tuple[PreTrainedModel, PreTrainedTokenizerFast]`: The loaded text model and tokenizer. + `tuple[PreTrainedModel, PreTrainedTokenizerFast]`: The loaded text model and processor. 
""" if model_id_and_revision not in self.loaded_models or self.loaded_models[model_id_and_revision].is_deleted(): - model, tokenizer = self._load_model_and_data_processor(model_id_and_revision, AutoModelForCausalLM) + model, processor = self._load_model_and_data_processor(model_id_and_revision) self.loaded_models[model_id_and_revision] = TimedModel( model, timeout_seconds=self.args.model_timeout, - processor=tokenizer, + processor=processor, ) else: self.loaded_models[model_id_and_revision].reset_timer() model = self.loaded_models[model_id_and_revision].model - tokenizer = self.loaded_models[model_id_and_revision].processor + processor = self.loaded_models[model_id_and_revision].processor - return model, tokenizer + return model, processor def load_audio_model_and_processor(self, model_id_and_revision: str) -> tuple[PreTrainedModel, ProcessorMixin]: """ @@ -1381,9 +1477,7 @@ class ServeCommand(BaseTransformersCLICommand): `tuple[PreTrainedModel, ProcessorMixin]`: The loaded audio model and processor. """ if model_id_and_revision not in self.loaded_models or self.loaded_models[model_id_and_revision].is_deleted(): - audio_model, audio_processor = self._load_model_and_data_processor( - model_id_and_revision, AutoModelForSpeechSeq2Seq - ) + audio_model, audio_processor = self._load_model_and_data_processor(model_id_and_revision) self.loaded_models[model_id_and_revision] = TimedModel( audio_model, timeout_seconds=self.args.model_timeout, diff --git a/tests/commands/test_serving.py b/tests/commands/test_serving.py index ed344ef7ed..f0592686e8 100644 --- a/tests/commands/test_serving.py +++ b/tests/commands/test_serving.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import asyncio +import os import time import unittest from threading import Thread @@ -23,7 +24,7 @@ from parameterized import parameterized import transformers.commands.transformers_cli as cli from transformers import GenerationConfig -from transformers.commands.serving import ServeArguments, ServeCommand +from transformers.commands.serving import Modality, ServeArguments, ServeCommand from transformers.testing_utils import CaptureStd, require_openai, slow from transformers.utils.import_utils import is_openai_available @@ -258,6 +259,104 @@ class ServeCompletionsMixin: # TODO: speed-based test to confirm that KV cache is working across requests +class ServeCompletionsGenerateMockTests(unittest.TestCase): + def test_processor_inputs_from_inbound_messages_llm(self): + modality = Modality.LLM + messages = expected_outputs = [ + {"role": "user", "content": "How are you doing?"}, + {"role": "assistant", "content": "I'm doing great, thank you for asking! How can I assist you today?"}, + {"role": "user", "content": "Can you help me write tests?"}, + ] + outputs = ServeCommand.get_processor_inputs_from_inbound_messages(messages, modality) + self.assertListEqual(expected_outputs, outputs) + + def test_processor_inputs_from_inbound_messages_vlm_text_only(self): + modality = Modality.VLM + messages = [ + {"role": "user", "content": "How are you doing?"}, + {"role": "assistant", "content": "I'm doing great, thank you for asking! How can I assist you today?"}, + {"role": "user", "content": "Can you help me write tests?"}, + ] + + expected_outputs = [ + {"role": "user", "content": [{"type": "text", "text": "How are you doing?"}]}, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "I'm doing great, thank you for asking! 
How can I assist you today?"} + ], + }, + {"role": "user", "content": [{"type": "text", "text": "Can you help me write tests?"}]}, + ] + + outputs = ServeCommand.get_processor_inputs_from_inbound_messages(messages, modality) + self.assertListEqual(expected_outputs, outputs) + + def test_processor_inputs_from_inbound_messages_vlm_text_and_image_in_base_64(self): + modality = Modality.VLM + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "How many pixels are in the image?"}, + { + "type": "image_url", + "image_url": { + "url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAASABIAAD/4QBARXhpZgAATU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAAqACAAQAAAABAAAABaADAAQAAAABAAAABQAAAAD/7QA4UGhvdG9zaG9wIDMuMAA4QklNBAQAAAAAAAA4QklNBCUAAAAAABDUHYzZjwCyBOmACZjs+EJ+/8AAEQgABQAFAwEiAAIRAQMRAf/EAB8AAAEFAQEBAQEBAAAAAAAAAAABAgMEBQYHCAkKC//EALUQAAIBAwMCBAMFBQQEAAABfQECAwAEEQUSITFBBhNRYQcicRQygZGhCCNCscEVUtHwJDNicoIJChYXGBkaJSYnKCkqNDU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6g4SFhoeIiYqSk5SVlpeYmZqio6Slpqeoqaqys7S1tre4ubrCw8TFxsfIycrS09TV1tfY2drh4uPk5ebn6Onq8fLz9PX29/j5+v/EAB8BAAMBAQEBAQEBAQEAAAAAAAABAgMEBQYHCAkKC//EALURAAIBAgQEAwQHBQQEAAECdwABAgMRBAUhMQYSQVEHYXETIjKBCBRCkaGxwQkjM1LwFWJy0QoWJDThJfEXGBkaJicoKSo1Njc4OTpDREVGR0hJSlNUVVZXWFlaY2RlZmdoaWpzdHV2d3h5eoKDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uLj5OXm5+jp6vLz9PX29/j5+v/bAEMAAQEBAQEBAgEBAgICAgICAwICAgIDBAMDAwMDBAUEBAQEBAQFBQUFBQUFBQYGBgYGBgcHBwcHCAgICAgICAgICP/bAEMBAQEBAgICAwICAwgFBAUICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICP/dAAQAAf/aAAwDAQACEQMRAD8A/v4ooooA/9k=" + }, + }, + ], + }, + { + "role": "assistant", + "content": "The number of pixels in the image cannot be determined from the provided information.", + }, + {"role": "user", "content": "Alright"}, + ] + + expected_outputs = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "How many pixels are in the image?"}, + {"type": "image", "url": 
"/var/folders/4v/64sxdhsd3gz3r8vhhnyc0mqw0000gn/T/tmp50oyghk6.png"}, + ], + }, + { + "role": "assistant", + "content": [ + { + "type": "text", + "text": "The number of pixels in the image cannot be determined from the provided information.", + } + ], + }, + {"role": "user", "content": [{"type": "text", "text": "Alright"}]}, + ] + + outputs = ServeCommand.get_processor_inputs_from_inbound_messages(messages, modality) + + for expected_output, output in zip(expected_outputs, outputs): + expected_output_content = expected_output["content"] + output_content = output["content"] + + self.assertEqual(type(expected_output_content), type(output_content)) + + if isinstance(expected_output_content, list): + for expected_output_content_item, output_content_item in zip(expected_output_content, output_content): + self.assertIn("type", expected_output_content_item) + self.assertIn("type", output_content_item) + self.assertTrue(expected_output_content_item["type"] == output_content_item["type"]) + + if expected_output_content_item["type"] == "text": + self.assertEqual(expected_output_content_item["text"], output_content_item["text"]) + + if expected_output_content_item["type"] == "image": + self.assertTrue(os.path.exists(output_content_item["url"])) + else: + raise ValueError("VLMs should only receive content as lists.") + + @slow # server startup time is slow on our push CI @require_openai class ServeCompletionsGenerateIntegrationTest(ServeCompletionsMixin, unittest.TestCase):