diff --git a/docs/source/en/serving.md b/docs/source/en/serving.md index 5fcd5d1203..5f73f7e136 100644 --- a/docs/source/en/serving.md +++ b/docs/source/en/serving.md @@ -71,7 +71,7 @@ vllm serve Qwen/Qwen2.5-1.5B-Instruct \ > This section is experimental and subject to change in future versions -You can serve LLMs supported by `transformers` with the `transformers serve` CLI. It spawns a local server that offers a chat Completions API compatible with the OpenAI SDK, which is the _de facto_ standard for LLM conversations. This way, you can use the server from many third party applications, or test it using the `transformers chat` CLI ([docs](conversations.md#chat-cli)). +You can serve LLMs supported by `transformers` with the `transformers serve` CLI. It spawns a local server that offers a Chat Completion API or a Response API compatible with the OpenAI SDK, which are the _de facto_ standard for LLM conversations. This way, you can use the server from many third party applications, or test it using the `transformers chat` CLI ([docs](conversations.md#chat-cli)). To launch a server, simply use the `transformers serve` CLI command: diff --git a/setup.py b/setup.py index ff84d79364..75e25e45be 100644 --- a/setup.py +++ b/setup.py @@ -137,6 +137,7 @@ _deps = [ "onnxconverter-common", "onnxruntime-tools>=1.4.2", "onnxruntime>=1.4.0", + "openai", "opencv-python", "optimum-benchmark>=0.3.0", "optuna", @@ -314,7 +315,7 @@ extras["hub-kernels"] = deps_list("kernels") extras["integrations"] = extras["hub-kernels"] + extras["optuna"] + extras["ray"] + extras["sigopt"] -extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette") + extras["torch"] +extras["serving"] = deps_list("openai", "pydantic", "uvicorn", "fastapi", "starlette") + extras["torch"] extras["audio"] = deps_list( "librosa", "pyctcdecode", diff --git a/src/transformers/commands/chat.py b/src/transformers/commands/chat.py index e74970f694..81a01932a0 100644 --- a/src/transformers/commands/chat.py +++ b/src/transformers/commands/chat.py @@ -471,7 +471,7 @@ class ChatCommand(BaseTransformersCLICommand): # This is a chat session, so we have a few non-standard defaults # !!!!!!!!! generation_config = copy.deepcopy(model_generation_config) - generation_config.update({"do_sample": True, "max_new_tokens": 256}) + generation_config.update(**{"do_sample": True, "max_new_tokens": 256}) # Finally: parse and apply `generate_flags` parsed_generate_flags = self.parse_generate_flags(args.generate_flags) diff --git a/src/transformers/commands/serving.py b/src/transformers/commands/serving.py index d79b2b2f3b..bf762b2214 100644 --- a/src/transformers/commands/serving.py +++ b/src/transformers/commands/serving.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import copy import functools import json @@ -19,11 +20,16 @@ import time from argparse import ArgumentParser, Namespace from dataclasses import dataclass, field from threading import Thread -from typing import Any, Optional +from typing import Generator, Optional from huggingface_hub import ModelInfo, model_info -from transformers.utils.import_utils import is_fastapi_available, is_pydantic_available, is_uvicorn_available +from transformers.utils.import_utils import ( + is_fastapi_available, + is_openai_available, + is_pydantic_available, + is_uvicorn_available, +) from .. import LogitsProcessorList, PreTrainedTokenizerFast, TextIteratorStreamer from ..generation.continuous_batching import ContinuousBatchingManager, RequestStatus @@ -42,53 +48,108 @@ if is_torch_available(): PreTrainedModel, ) - -if is_pydantic_available() and is_fastapi_available() and is_uvicorn_available(): +serve_dependencies_available = ( + is_pydantic_available() and is_fastapi_available() and is_uvicorn_available() and is_openai_available() +) +if serve_dependencies_available: import uvicorn - from fastapi import FastAPI + from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse - from pydantic import BaseModel + from openai.types.chat.chat_completion_chunk import ( + ChatCompletionChunk, + Choice, + ChoiceDelta, + ChoiceDeltaToolCall, + ChoiceDeltaToolCallFunction, + ) + from openai.types.chat.completion_create_params import CompletionCreateParamsStreaming + from openai.types.responses import ( + Response, + ResponseCompletedEvent, + ResponseContentPartAddedEvent, + ResponseContentPartDoneEvent, + ResponseCreatedEvent, + ResponseError, + ResponseErrorEvent, + ResponseFailedEvent, + ResponseInProgressEvent, + ResponseOutputItemAddedEvent, + ResponseOutputItemDoneEvent, + ResponseOutputMessage, + ResponseOutputText, + ResponseTextDeltaEvent, + ResponseTextDoneEvent, + ) + from openai.types.responses.response_create_params import ResponseCreateParamsStreaming + from pydantic import BaseModel, TypeAdapter, ValidationError - class Message(BaseModel): - role: str - content: str + # Expand OpenAI's request input types with an optional `generation_config` field + class TransformersResponseCreateParamsStreaming(ResponseCreateParamsStreaming, total=False): + """ + OpenAI's ResponseCreateParamsStreaming with an additional field for the generation config (as a json string). + """ - class ChatCompletionInput(BaseModel): - messages: list[Message] + generation_config: Optional[str] - stream: Optional[bool] = False - model: Optional[str] = None - request_id: Optional[str] = None - extra_body: Optional[dict] = None - frequency_penalty: Optional[float] = None - logit_bias: Optional[list[float]] = None - max_tokens: Optional[int] = None - stop: Optional[list[str]] = None - temperature: Optional[float] = None - top_p: Optional[float] = None - seed: Optional[int] = None + class TransformersCompletionCreateParamsStreaming(CompletionCreateParamsStreaming, total=False): + """ + OpenAI's CompletionCreateParamsStreaming with additional fields for the generation config (as a json string) + and the request ID to re-use the previous KV cache. + """ - # Additional options supported by the HFH InferenceClient - # that aren't yet supported here. + generation_config: Optional[str] + request_id: Optional[str] - # logprobs: Optional[bool] = None - tools: Any = None - # n: Optional[int] = None - # presence_penalty: Optional[float] = None - # response_format: Optional[ChatCompletionInputGrammarType] = None - # stream_options: Optional[ChatCompletionInputStreamOptions] = None - # tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None - # tool_prompt: Optional[str] = None - # top_logprobs: Optional[int] = None + # Contrarily to OpenAI's output types, input types are `TypedDict`, which don't have validation + response_validator = TypeAdapter(TransformersResponseCreateParamsStreaming) + completion_validator = TypeAdapter(TransformersCompletionCreateParamsStreaming) - # transformers-specific request fields - generation_config: Optional[str] = None + # Define request fields that are not yet used in `transformers serve`. Receiving these fields will raise an + # HTTPException. + UNUSED_RESPONSE_FIELDS = { + "background", + "include", + "max_tool_calls", + "previous_response_id", + "prompt", + "reasoning", + "service_tier", + "store", + "text", + "tool_choice", + "top_logprobs", + "truncation", + "user", + } + + UNUSED_CHAT_COMPLETION_FIELDS = { + "audio", + "function_call", + "functions", + "logprobs", + "max_completion_tokens", + "metadata", + "modalities", + "n", + "parallel_tool_calls", + "prediction", + "presence_penalty", + "reasoning_effort", + "response_format", + "service_tier", + "stop", + "store", + "stream_options", + "tool_choice", + "top_logprobs", + "user", + "web_search_options", + } logger = logging.get_logger(__name__) - # Possible tokens that indicate the start/end of a tool call # TODO (joao, matt): streamline tool token detection logic _TOOL_CALL_TOKENS = { @@ -110,7 +171,9 @@ def serve_command_factory(args: Namespace): def create_generation_config_from_req( - req: "ChatCompletionInput", model_generation_config: "GenerationConfig", **kwargs + req: dict, + model_generation_config: "GenerationConfig", + **kwargs, ) -> "GenerationConfig": """ Creates a generation config from the parameters of the request. If a generation config is passed in the request, @@ -118,18 +181,20 @@ def create_generation_config_from_req( Other parameters in the request will be applied on top of the baseline. Args: - req (`ChatCompletionInput`): + req (`dict`): The request which may optionally contain generation parameters. model_generation_config (`GenerationConfig`): The model's default generation config. + kwargs (`dict`): + Additional parameters to set in the generation config. Returns: The prepared `GenerationConfig` object. """ # If there is a generation config in the request, it is a json string serialization from a `GenerationConfig` # object. For simplicity, flags set here take precedence over all other flags. - if req.generation_config is not None: - generation_config = GenerationConfig(**json.loads(req.generation_config)) + if req.get("generation_config") is not None: + generation_config = GenerationConfig(**json.loads(req["generation_config"])) else: generation_config = copy.deepcopy(model_generation_config) @@ -139,20 +204,31 @@ def create_generation_config_from_req( if v is not None: setattr(generation_config, k, v) - if req.frequency_penalty is not None: - generation_config.repetition_penalty = float(req.frequency_penalty) - if req.logit_bias is not None: - generation_config.sequence_bias = req.logit_bias - if req.stop is not None: - generation_config.stop_strings = req.stop - if req.temperature is not None: - generation_config.temperature = float(req.temperature) - if float(req.temperature) == 0.0: + # Response-specific parameters + if req.get("max_output_tokens") is not None: + generation_config.max_new_tokens = int(req["max_output_tokens"]) + + # Completion-specific parameters + if req.get("max_tokens") is not None: + generation_config.max_new_tokens = int(req["max_tokens"]) + if req.get("frequency_penalty") is not None: + generation_config.repetition_penalty = float(req["frequency_penalty"]) + if req.get("logit_bias") is not None: + generation_config.sequence_bias = req["logit_bias"] + if req.get("stop") is not None: + generation_config.stop_strings = req["stop"] + if req.get("temperature") is not None: + generation_config.temperature = float(req["temperature"]) + if float(req["temperature"]) == 0.0: generation_config.do_sample = False - if req.top_p is not None: - generation_config.top_p = float(req.top_p) - if req.seed is not None: - torch.manual_seed(req.seed) + if req.get("top_p") is not None: + generation_config.top_p = float(req["top_p"]) + if req.get("seed") is not None: + torch.manual_seed(req["seed"]) + + # Sets server-specific defaults, if unset + if generation_config.max_new_tokens is None: + generation_config.max_new_tokens = 1024 return generation_config @@ -228,14 +304,28 @@ class ServeArguments: }, ) + # TODO + # Testing + # As of 2025-07-11, testing on https://github.com/openai/openai-responses-starter-app/, validation on the + # Response input is failing. The app works well without validation. Enable at some point in the future. + input_validation: bool = field( + default=False, + metadata={ + "help": ("Whether to turn on strict input validation."), + }, + ) + force_model: Optional[str] = field( + default=None, + metadata={ + "help": ( + "Name of the model to be forced on all requests. This is useful for testing Apps that don't allow " + "changing models in the request." + ), + }, + ) + class ServeCommand(BaseTransformersCLICommand): - loaded_model: Optional[str] = None - running_continuous_batching_manager: Optional[ContinuousBatchingManager] = None - - model: PreTrainedModel - tokenizer: PreTrainedTokenizerFast - @staticmethod def register_subcommand(parser: ArgumentParser): """ @@ -249,39 +339,113 @@ class ServeCommand(BaseTransformersCLICommand): serve_parser.set_defaults(func=serve_command_factory) def __init__(self, args: ServeArguments): - if not is_pydantic_available() or not is_fastapi_available() or not is_uvicorn_available(): + if not serve_dependencies_available: raise ImportError( "Missing dependencies for the serving CLI. Please install with `pip install transformers[serving]`" ) + # Store and process input arguments self.args = args self.use_continuous_batching = self.args.attn_implementation == "sdpa_paged" self.enable_cors = self.args.enable_cors - # State: preserves information about the last call and last KV cache, to determine whether we can reuse the KV - # cache and avoid re-running prefil - self.last_messages = None - self.last_kv_cache = None - + # Set up logging transformers_logger = logging.get_logger("transformers") transformers_logger.setLevel(logging.log_levels[self.args.log_level.lower()]) cb_logger = logging.get_logger("transformers.generation.continuous_batching") cb_logger.setLevel(logging.log_levels[self.args.log_level.lower()]) - def build_chunk( + # Internal state: + # 1. Tracks the most recently used model, to prevent reloading the model unnecessarily + self.loaded_model: Optional[str] = None + self.running_continuous_batching_manager: Optional[ContinuousBatchingManager] = None + self.model: PreTrainedModel + self.tokenizer: PreTrainedTokenizerFast + + # 2. preserves information about the last call and last KV cache, to determine whether we can reuse the KV + # cache and avoid re-running prefil + self.last_messages = None + self.last_kv_cache = None + + def _validate_request( self, - request_id: str, + request: dict, + schema: "_TypedDictMeta", # noqa: F821 + validator: "TypeAdapter", + unused_fields: set, + ): + """ + Validates the request against the schema, and checks for unexpected keys. + + Args: + request (`dict`): + The request to validate. + schema (`_TypedDictMeta`): + The schema of the request to validate. It is a `TypedDict` definition. + validator (`TypeAdapter`): + The validator to use to validate the request. Built from `schema`. + unused_fields (`set`): + Fields accepted by `schema`, but not used in `transformers serve`. + + Raises: + HTTPException: If the request is invalid or contains unexpected or unused fields. + """ + logger.debug(f"Validating request: {request}") + + # Validate unexpected keys -- Pydantic doesn't validate extra keys in the request. + input_keys = set(request.keys()) + possible_keys = schema.__mutable_keys__ + unexpected_keys = input_keys - possible_keys + if unexpected_keys: + logger.error(f"Unexpected keys in the request: {unexpected_keys}") + raise HTTPException(status_code=422, detail=f"Unexpected keys in the request: {unexpected_keys}") + + if self.args.input_validation: + # Validate expected keys + try: + validator.validate_python(request) + except ValidationError as e: + logger.error(f"Validation error: {e.errors()}") + raise HTTPException(status_code=422, detail=e.errors()) + + # Validate unused fields + unused_fields_in_request = input_keys & unused_fields + if unused_fields_in_request: + logger.error(f"Unused fields in the request: {unused_fields_in_request}") + raise HTTPException( + status_code=422, detail=f"Unused fields in the request: {unused_fields_in_request}" + ) + + def validate_response_request(self, request: dict): + self._validate_request( + request=request, + schema=TransformersResponseCreateParamsStreaming, + validator=response_validator, + unused_fields=UNUSED_RESPONSE_FIELDS, + ) + + def validate_chat_completion_request(self, request: dict): + self._validate_request( + request=request, + schema=TransformersCompletionCreateParamsStreaming, + validator=completion_validator, + unused_fields=UNUSED_CHAT_COMPLETION_FIELDS, + ) + + def build_chat_completion_chunk( + self, + request_id: Optional[str] = "", content: Optional[str] = None, role: Optional[str] = None, finish_reason: Optional[str] = None, - tool_calls: Optional[list[dict]] = None, + tool_calls: Optional[list["ChoiceDeltaToolCall"]] = None, ) -> str: """ - Builds a chunk of a streaming response. + Builds a chunk of a streaming OpenAI Chat Completion response. - IMPORTANT: The built chunk won't contain empty fields (fields with `None`). Some downstream apps, like Cursor, - assume that when the field exists, it has data. + IMPORTANT: The serialized chunk won't contain empty fields (fields with `None`). Some downstream apps, + like Cursor, assume that when the field exists, it has data. Args: request_id (`str`): @@ -292,30 +456,47 @@ class ServeCommand(BaseTransformersCLICommand): The role of the next content, until a new role is defined. finish_reason (`str`, *optional*): The reason the generation by the model has finished. - tool_calls (`list[dict]`, *optional*): + tool_calls (`list[ChoiceDeltaToolCall]`, *optional*): Data about the tool calls, when they are triggered. Returns: `str`: The built chunk, a string containing a JSON string with the payload. """ - payload = { - "object": "chat.completion.chunk", - "id": request_id, - "created": int(time.time()), - "model": self.loaded_model, - "choices": [{"delta": {}, "index": 0}], - "system_fingerprint": "", - } - if content is not None: - payload["choices"][0]["delta"]["content"] = content - if role is not None: - payload["choices"][0]["delta"]["role"] = role - if tool_calls is not None: - payload["choices"][0]["delta"]["tool_calls"] = tool_calls - if finish_reason is not None: - payload["choices"][0]["finish_reason"] = finish_reason + chunk = ChatCompletionChunk( + id=request_id, + created=int(time.time()), + model=self.loaded_model, + choices=[ + Choice( + delta=ChoiceDelta( + content=content, + role=role, + tool_calls=tool_calls, + ), + index=0, + finish_reason=finish_reason, + ) + ], + system_fingerprint="", + object="chat.completion.chunk", + ) + return f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" - return f"data: {json.dumps(payload)}\n\n" + def build_response_event(self, response: "BaseModel") -> str: + """ + Builds a event of a streaming OpenAI Response response. + + IMPORTANT: The serialized chunk won't contain empty fields (fields with `None`). Some downstream apps, + like Cursor, assume that when the field exists, it has data. + + Args: + response (`BaseModel`): + The response to build an event from. One of the multiple OpenAI Response output types + + Returns: + `str`: The built chunk, a string containing a JSON string with the payload. + """ + return f"data: {response.model_dump_json(exclude_none=True)}\n\n" def run(self): app = FastAPI() @@ -331,31 +512,22 @@ class ServeCommand(BaseTransformersCLICommand): allow_headers=["*"], ) - if self.use_continuous_batching: - self.continuous_batching(app) - else: - self.generate(app) + @app.post("/v1/chat/completions") + def chat_completion(request: dict): + self.validate_chat_completion_request(request=request) - @functools.lru_cache(maxsize=None) - def get_text_gen_models() -> list[ModelInfo]: - """ - This is by no means a limit to which models may be instantiated with `transformers serve`: any chat-based - model working with generate can work. + if self.use_continuous_batching: + output = self.continuous_batching_chat_completion(request) + else: + output = self.generate_chat_completion(request) + return StreamingResponse(output, media_type="text/event-stream") - This is a limited list of models to ensure we have a discoverable /v1/models endpoint for third-party - integrations. - """ - return [ - model_info("Menlo/Jan-nano"), - model_info("Menlo/Jan-nano-128k"), - model_info("Qwen/Qwen2.5-0.5B-Instruct"), - model_info("Qwen/Qwen2.5-3B-Instruct"), - model_info("Qwen/Qwen2.5-7B-Instruct"), - model_info("Qwen/Qwen2.5-14B-Instruct"), - model_info("meta-llama/Llama-3.1-8B-Instruct"), - model_info("meta-llama/Llama-3.2-1B-Instruct"), - model_info("meta-llama/Llama-3.3-70B-Instruct"), - ] + @app.post("/v1/responses") + def responses(request: dict): + self.validate_response_request(request=request) + + output = self.generate_response(request) + return StreamingResponse(output, media_type="text/event-stream") @app.get("/v1/models") def get_all_models(): @@ -369,284 +541,565 @@ class ServeCommand(BaseTransformersCLICommand): "created": model.created_at.timestamp(), "owned_by": model.author, } - for model in get_text_gen_models() + for model in self.get_text_gen_models() ], } ) uvicorn.run(app, host=self.args.host, port=self.args.port, log_level=self.args.log_level) - def continuous_batching(self, app): - @app.post("/v1/chat/completions") - def _serve(req: "ChatCompletionInput"): - if not req.stream: - return {"error": "Only streaming mode is supported."} + @functools.lru_cache(maxsize=None) + def get_text_gen_models(self) -> list[ModelInfo]: + """ + This is by no means a limit to which models may be instantiated with `transformers serve`: any chat-based + model working with generate can work. - update_model = self.canonicalized_model_name(req.model) != self.loaded_model - if update_model: - self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args) + This is a limited list of models to ensure we have a discoverable /v1/models endpoint for third-party + integrations. + """ + return [ + model_info("Menlo/Jan-nano"), + model_info("Menlo/Jan-nano-128k"), + model_info("Qwen/Qwen2.5-0.5B-Instruct"), + model_info("Qwen/Qwen2.5-3B-Instruct"), + model_info("Qwen/Qwen2.5-7B-Instruct"), + model_info("Qwen/Qwen2.5-14B-Instruct"), + model_info("meta-llama/Llama-3.1-8B-Instruct"), + model_info("meta-llama/Llama-3.2-1B-Instruct"), + model_info("meta-llama/Llama-3.3-70B-Instruct"), + ] - generation_config = create_generation_config_from_req( - req, - model_generation_config=self.model.generation_config, - eos_token_id=self.tokenizer.eos_token_id, - pad_token_id=self.tokenizer.pad_token_id, - use_cache=False, - num_blocks=1, - block_size=1024, - do_sample=False, - max_batch_tokens=10, - scheduler="fifo", + def continuous_batching_chat_completion(self, req: dict) -> Generator[str, None, None]: + """ + Generates an OpenAI Chat Completion using continuous batching. + + Args: + req (`dict`): The request to generate an OpenAI Chat Completion for. + + Returns: + `Generator[str, None, None]`: A generator that yields the OpenAI Chat Completion chunks. + """ + if self.args.force_model is not None: + req["model"] = self.args.force_model + + update_model = self.canonicalized_model_name(req["model"]) != self.loaded_model + if update_model: + # When switching models, terminate a continuous batching manager if it is running. + if self.running_continuous_batching_manager is not None: + self.running_continuous_batching_manager.stop(block=True, timeout=2) + self.running_continuous_batching_manager = None + self.load_model_and_tokenizer(req["model"], self.args) + + generation_config = create_generation_config_from_req( + req, + model_generation_config=self.model.generation_config, + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.pad_token_id, + use_cache=False, + num_blocks=1, + block_size=1024, + do_sample=False, + max_batch_tokens=10, + scheduler="fifo", + ) + + if self.running_continuous_batching_manager is None: + self.running_continuous_batching_manager = self.model.init_continuous_batching( + generation_config=generation_config, streaming=True ) - if self.running_continuous_batching_manager is None or update_model: - self.running_continuous_batching_manager = self.model.init_continuous_batching( - generation_config=generation_config, streaming=True + # TODO (Joao, Lysandre): the logits processors should be fixed in continuous batching + # and correctly applied in non-cb + self.running_continuous_batching_manager.logit_processor = LogitsProcessorList() + self.running_continuous_batching_manager.start() + + # TODO (Joao, Lysandre): this should also work with tool support + inputs = self.tokenizer.apply_chat_template( + req["messages"], return_tensors="pt", add_generation_prompt=True + ).to(self.model.device) + + def stream_chat_completion(_inputs): + try: + request_id = self.running_continuous_batching_manager.add_request( + _inputs, request_id=req.get("request_id"), max_new_tokens=generation_config.max_new_tokens ) - # TODO (Joao, Lysandre): the logits processors should be fixed in continuous batching - # and correctly applied in non-cb - self.running_continuous_batching_manager.logit_processor = LogitsProcessorList() - self.running_continuous_batching_manager.start() + queue_is_flushed = False - # TODO (Joao, Lysandre): this should also work with tool support - inputs = self.tokenizer.apply_chat_template( - req.messages, return_tensors="pt", add_generation_prompt=True - ).to(self.model.device) + # Emit the assistant role to start the stream. Other chunks won't have a role, as it is implicit + # they come from the assistant. + yield self.build_chat_completion_chunk(request_id, role="assistant") - def stream_response(_inputs): - try: - max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 1024 - request_id = self.running_continuous_batching_manager.add_request( - _inputs, request_id=req.request_id, max_new_tokens=max_new_tokens - ) - queue_is_flushed = False + for result in self.running_continuous_batching_manager: + if result.request_id != request_id: + continue + if req.get("request_id") is not None and not queue_is_flushed: + if result.status == RequestStatus.FINISHED: + continue + else: + queue_is_flushed = True - # Emit the assistant role to start the stream. Other chunks won't have a role, as it is implicit - # they come from the assistant. - yield self.build_chunk(request_id, role="assistant") + finish_reason = "stop" if result.status == RequestStatus.FINISHED else None + if result.status == RequestStatus.FINISHED: + yield self.build_chat_completion_chunk(request_id, finish_reason=finish_reason) + break + else: + yield self.build_chat_completion_chunk(request_id=request_id, content=result.next_token) - for result in self.running_continuous_batching_manager: - if result.request_id != request_id: + except Exception as e: + logger.error(str(e)) + yield f'data: {{"error": "{str(e)}"}}' + + return stream_chat_completion(inputs[0]) + + def generate_chat_completion(self, req: dict) -> Generator[str, None, None]: + """ + Generates an OpenAI Chat Completion using `generate`. + + Args: + req (`dict`): The request to generate an OpenAI Chat Completion for. + + Returns: + `Generator[str, None, None]`: A generator that yields the OpenAI Chat Completion chunks. + """ + if self.args.force_model is not None: + req["model"] = self.args.force_model + + update_model = self.canonicalized_model_name(req["model"]) != self.loaded_model + if update_model: + self.load_model_and_tokenizer(req["model"], self.args) + + # HACK for tiny-agents: it sends a request after the assistant message (???). Let's assume we can't have a + # request whose last message is from the assistant. + if req["messages"][-1]["role"] == "assistant": + return + + # ====== TOOL PREPROCESSING LOGIC ====== + tool_model_family = None + for supported_model_families in _MODELS_WITH_TOOL_SUPPORT: + if supported_model_families in self.model.config.architectures[0].lower(): + tool_model_family = supported_model_families + break + # TODO: trigger 2 constrained generations after the tool call start token is emitted: + # 1. force generation to pick from the tool names + # 2. force generation to pick from that tool's arguments + # ====== END OF TOOL PREPROCESSING LOGIC ====== + + if tool_model_family is not None: + text = self.tokenizer.apply_chat_template( + req["messages"], add_generation_prompt=True, tokenize=False, tools=req.get("tools") + ) + else: + text = self.tokenizer.apply_chat_template(req["messages"], add_generation_prompt=True, tokenize=False) + inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)["input_ids"] + request_id = req.get("request_id", "req_0") + + generation_streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True, skip_prompt=True) + generation_config = create_generation_config_from_req( + req, model_generation_config=self.model.generation_config + ) + + last_kv_cache = None + if self.is_continuation(req) and not update_model: + last_kv_cache = self.last_kv_cache + + generation_kwargs = { + "inputs": inputs, + "attention_mask": torch.ones_like(inputs), + "streamer": generation_streamer, + "generation_config": generation_config, + "return_dict_in_generate": True, + "past_key_values": last_kv_cache, + } + + def stream_chat_completion(streamer, _request_id): + # Thin wrapper to save the KV cache after generation + def generate_with_cache(**kwargs): + generate_output = self.model.generate(**kwargs) + self.last_kv_cache = generate_output.past_key_values + + thread = Thread(target=generate_with_cache, kwargs=generation_kwargs) + + try: + thread.start() + tool_state = ToolState() + + # Emit the assistant role to start the stream. Other chunks won't have a role, as it is implicit + # they come from the assistant. + yield self.build_chat_completion_chunk(request_id, role="assistant") + + for result in streamer: + # ====== TOOL CALL LOGIC ====== + if tool_model_family is not None: + # Start of a tool call: reset state variables, set `inside_tool_call` + if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["start"]: + tool_state.inside_tool_call = True continue - if req.request_id is not None and not queue_is_flushed: - if result.status == RequestStatus.FINISHED: - continue + # End of tool call: reset `inside_tool_call`, emit a `finish_reason` + if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["end"]: + tool_state.reset() + yield self.build_chat_completion_chunk( + request_id=_request_id, role=None, finish_reason="tool_calls" + ) + continue + + # Inside a tool call + if tool_state.inside_tool_call: + tool_state.buffer += result + + # First step: extract the tool name (may need several tokens, and we can't emit a delta + # until we have the full name) + if not tool_state.has_tool_name_defined: + tool_name = re.search(r"\"name\": \"(.*?)\"", tool_state.buffer) + if tool_name is None: + continue + else: + tool_name = tool_name.group(1) + tool_state.has_tool_name_defined = True + tool = ChoiceDeltaToolCall( + function=ChoiceDeltaToolCallFunction(name=tool_name), + index=0, + type="function", + id=_request_id + "_tool_call", # Only the first tool call delta has an id + ) + + # Second step: extract tool arguments. The tool arguments can be seen as a json string + # within the tool json string. We emit a delta for the arguments. else: - queue_is_flushed = True + # Empty text: skip + if result == "": + continue + # Until we see the `"arguments": {` in the buffer, we skip + # TODO: other models will likely need more elaborate processing here + if '"arguments": {' not in tool_state.buffer: + continue - finish_reason = "stop" if result.status == RequestStatus.FINISHED else None - if result.status == RequestStatus.FINISHED: - yield self.build_chunk(request_id, finish_reason=finish_reason) - break - else: - yield self.build_chunk(request_id=request_id, content=result.next_token) + # Handle nesting. We want to exclude the last } from the emitted arguments (it's + # closing the outermost nesting level, outside the arguments block) + tool_state.arg_nesting_level += result.count("{") + tool_state.arg_nesting_level -= result.count("}") + if tool_state.arg_nesting_level < 0: + result = "".join(result.split("}")[:-2]) + "}" # e.g. "4}}\n" -> "4}" - except Exception as e: - logger.error(str(e)) - yield f'data: {{"error": "{str(e)}"}}' + tool = ChoiceDeltaToolCall( + function=ChoiceDeltaToolCallFunction(arguments=result), + index=0, + type="function", + ) - return StreamingResponse(stream_response(inputs[0]), media_type="text/event-stream") + yield self.build_chat_completion_chunk( + request_id=_request_id, role=None, tool_calls=[tool] + ) + continue + # ====== END OF TOOL CALL LOGIC ====== - def is_continuation(self, req: "ChatCompletionInput") -> bool: + # All non-tool related tokens are emitted as assistant messages. Empty text is skipped. + if result != "": + yield self.build_chat_completion_chunk(_request_id, content=result) + yield self.build_chat_completion_chunk(_request_id, finish_reason="stop") + + thread.join() + except Exception as e: + logger.error(str(e)) + yield f'data: {{"error": "{str(e)}"}}' + + finally: + thread.join() + + return stream_chat_completion(generation_streamer, request_id) + + def generate_response(self, req: dict) -> Generator[str, None, None]: + """ + Generates an OpenAI Response using `generate`. + + Args: + req (`dict`): The request to generate an OpenAI Response for. + + Returns: + `Generator[str, None, None]`: A generator that yields the OpenAI Response events. + """ + # TODO -- Implement non-streaming mode + if self.args.force_model is not None: + req["model"] = self.args.force_model + + update_model = self.canonicalized_model_name(req["model"]) != self.loaded_model + if update_model: + self.load_model_and_tokenizer(req["model"], self.args) + + text = self.tokenizer.apply_chat_template(req["input"], add_generation_prompt=True, tokenize=False) + inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)["input_ids"] + request_id = req.get("previous_response_id", "req_0") + + generation_streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True, skip_prompt=True) + generation_config = create_generation_config_from_req( + req, model_generation_config=self.model.generation_config + ) + + last_kv_cache = None + if self.is_continuation(req) and not update_model: + last_kv_cache = self.last_kv_cache + + generation_kwargs = { + "inputs": inputs, + "attention_mask": torch.ones_like(inputs), + "streamer": generation_streamer, + "generation_config": generation_config, + "return_dict_in_generate": True, + "past_key_values": last_kv_cache, + } + + def stream_response(streamer, _request_id): + thread = Thread(target=self.model.generate, kwargs=generation_kwargs) + sequence_number = 0 + output_index = 0 + content_index = 0 + + try: + thread.start() + created_at = time.time() # the spec expects a unix timestamp in seconds + + # We start by acknowledging the request (the request has `status="queued"`), and then by moving it to + # in progress (`status="in_progress"`) + response_created = ResponseCreatedEvent( + type="response.created", + sequence_number=sequence_number, + response=Response( + id=f"resp_{request_id}", + created_at=created_at, + status="queued", + model=self.loaded_model, + instructions=req.get("instructions"), + text={"format": {"type": "text"}}, + object="response", + tools=[], + output=[], + parallel_tool_calls=req.get("parallel_tool_calls", False), + tool_choice="auto", + metadata=req.get("metadata"), + ), + ) + sequence_number += 1 + yield self.build_response_event(response_created) + + response_in_progress = ResponseInProgressEvent( + type="response.in_progress", + sequence_number=sequence_number, + response=Response( + id=f"resp_{request_id}", + created_at=created_at, + status="in_progress", + model=self.loaded_model, + instructions=req.get("instructions"), + text={"format": {"type": "text"}}, + object="response", + tools=[], + output=[], + parallel_tool_calls=req.get("parallel_tool_calls", False), + tool_choice="auto", + metadata=req.get("metadata"), + ), + ) + sequence_number += 1 + yield self.build_response_event(response_in_progress) + + # Start the output item. Emit the assistant role to start the stream. Other chunks won't have a role, + # as it is implicit + response_output_item_added = ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=sequence_number, + output_index=output_index, + item=ResponseOutputMessage( + id=f"msg_{request_id}", type="message", status="in_progress", role="assistant", content=[] + ), + ) + sequence_number += 1 + yield self.build_response_event(response_output_item_added) + + # Start the content part of the event + response_content_part_added = ResponseContentPartAddedEvent( + type="response.content_part.added", + item_id=f"msg_{request_id}", + sequence_number=sequence_number, + output_index=output_index, + content_index=content_index, + part=ResponseOutputText(type="output_text", text="", annotations=[]), + ) + sequence_number += 1 + yield self.build_response_event(response_content_part_added) + + # Stream the actual generated text + results = "" + for result in streamer: + results += result + response_output_text_delta = ResponseTextDeltaEvent( + type="response.output_text.delta", + item_id=f"msg_{request_id}", + sequence_number=sequence_number, + output_index=output_index, + content_index=content_index, + delta=result, + ) + sequence_number += 1 + yield self.build_response_event(response_output_text_delta) + + # Signal the end of the text generation + response_output_text_done = ResponseTextDoneEvent( + type="response.output_text.done", + item_id=f"msg_{request_id}", + sequence_number=sequence_number, + output_index=output_index, + content_index=0, + text=results, + ) + sequence_number += 1 + yield self.build_response_event(response_output_text_done) + + # Complete the content part + response_content_part_done = ResponseContentPartDoneEvent( + type="response.content_part.done", + item_id=f"msg_{request_id}", + sequence_number=sequence_number, + output_index=output_index, + content_index=content_index, + part=ResponseOutputText(type="output_text", text=response_output_text_done.text, annotations=[]), + ) + sequence_number += 1 + content_index += 1 + yield self.build_response_event(response_content_part_done) + + # Complete the output item + response_output_item_done = ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=sequence_number, + output_index=output_index, + item=ResponseOutputMessage( + id=f"msg_{request_id}", + type="message", + status="completed", + role="assistant", + content=[response_content_part_done.part], + annotations=[], + ), + ) + sequence_number += 1 + output_index += 1 + yield self.build_response_event(response_output_item_done) + + # Finally, Complete the event + response_completed = ResponseCompletedEvent( + type="response.completed", + sequence_number=sequence_number, + response=Response( + id=f"resp_{request_id}", + created_at=created_at, + status="completed", + model=self.loaded_model, + instructions=req.get("instructions"), + text={"format": {"type": "text"}}, + output=[response_output_item_done.item], + object="response", + tools=[], + parallel_tool_calls=req.get("parallel_tool_calls", False), + tool_choice="auto", + metadata=req.get("metadata"), + ), + ) + sequence_number += 1 + yield self.build_response_event(response_completed) + + thread.join() + except Exception as e: + logger.error(f"Exception in response generation: {str(e)}") + error_event = ResponseErrorEvent( + type="error", + sequence_number=sequence_number, + message=str(e), + ) + sequence_number += 1 + yield self.build_response_event(error_event) + + response_failed = ResponseFailedEvent( + type="response.failed", + sequence_number=sequence_number, + response=Response( + id=f"resp_{request_id}", + created_at=created_at, + status="failed", + model=self.loaded_model, + instructions=req.get("instructions"), + text={"format": {"type": "text"}}, + output=[], + object="response", + tools=[], + parallel_tool_calls=False, + tool_choice="auto", + metadata=req.get("metadata"), + error=ResponseError( + code="server_error", + message=str(e), + ), + ), + ) + sequence_number += 1 + yield self.build_response_event(response_failed) + + finally: + thread.join() + + return stream_response(generation_streamer, request_id) + + def is_continuation(self, req: dict) -> bool: """ Determines whether the current request is a continuation of the last request. In other words, if it is the same chat session. Args: - req (`ChatCompletionInput`): The request to check. + req (`dict`): The request to check. Returns: `True` if the request is a continuation of the last request, `False` otherwise. """ + messages = req.get("messages") or req.get("input") # ChatCompletion and Response have different fields req_continues_last_messages = True # No cached messages: this is a new request if self.last_messages is None: req_continues_last_messages = False # The new request has no new rounds of conversation: this is a new request - elif len(self.last_messages) >= len(req.messages): + elif len(self.last_messages) >= len(messages): req_continues_last_messages = False # Otherwise, check that the last messages are a subset of the new request else: for i in range(len(self.last_messages)): - if self.last_messages[i] != req.messages[i]: + if self.last_messages[i] != messages[i]: req_continues_last_messages = False break - self.last_messages = req.messages + self.last_messages = messages return req_continues_last_messages - def generate(self, app): - @app.post("/v1/chat/completions") - def _serve(req: "ChatCompletionInput"): - logger.debug(f"Received request: {req}") - update_model = self.canonicalized_model_name(req.model) != self.loaded_model - - if update_model: - self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args) - - if not req.stream: - return {"error": "Only streaming mode is supported."} - - # HACK for tiny-agents: it sends a request after the assistant message (???). Let's assume we can't have a - # request whose last message is from the assistant. - if req.messages[-1].role == "assistant": - return - - # ====== TOOL PREPROCESSING LOGIC ====== - tool_model_family = None - for supported_model_families in _MODELS_WITH_TOOL_SUPPORT: - if supported_model_families in self.model.config.architectures[0].lower(): - tool_model_family = supported_model_families - break - # TODO: trigger 2 constrained generations after the tool call start token is emitted: - # 1. force generation to pick from the tool names - # 2. force generation to pick from that tool's arguments - # ====== END OF TOOL PREPROCESSING LOGIC ====== - - if tool_model_family is not None: - text = self.tokenizer.apply_chat_template( - req.messages, add_generation_prompt=True, tokenize=False, tools=req.tools - ) - else: - text = self.tokenizer.apply_chat_template(req.messages, add_generation_prompt=True, tokenize=False) - - inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)["input_ids"] - request_id = req.request_id if req.request_id is not None else "req_0" - - generation_streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True, skip_prompt=True) - - generation_config = create_generation_config_from_req( - req, - model_generation_config=self.model.generation_config, - ) - max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 1024 - generation_config.max_new_tokens = max_new_tokens - - last_kv_cache = None - if self.is_continuation(req) and not update_model: - last_kv_cache = self.last_kv_cache - - generation_kwargs = { - "inputs": inputs, - "attention_mask": torch.ones_like(inputs), - "streamer": generation_streamer, - "generation_config": generation_config, - "return_dict_in_generate": True, - "past_key_values": last_kv_cache, - } - - def stream_response(streamer, _request_id): - # Thin wrapper to save the KV cache after generation - def generate_with_cache(**kwargs): - generate_output = self.model.generate(**kwargs) - self.last_kv_cache = generate_output.past_key_values - - thread = Thread(target=generate_with_cache, kwargs=generation_kwargs) - - try: - thread.start() - tool_state = ToolState() - - # Emit the assistant role to start the stream. Other chunks won't have a role, as it is implicit - # they come from the assistant. - logger.debug("Starting model output") - yield self.build_chunk(_request_id, role="assistant") - - for result in streamer: - logger.debug(f"Model output: {result}") - - # ====== TOOL CALL LOGIC ====== - if tool_model_family is not None: - # Start of a tool call: reset state variables, set `inside_tool_call` - if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["start"]: - tool_state.inside_tool_call = True - continue - - # End of tool call: reset `inside_tool_call`, emit a `finish_reason` - if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["end"]: - tool_state.reset() - yield self.build_chunk(_request_id, finish_reason="tool_calls") - continue - - # Inside a tool call - if tool_state.inside_tool_call: - tool_state.buffer += result - - # First step: extract the tool name (may need several tokens, and we can't emit a delta - # until we have the full name) - if not tool_state.has_tool_name_defined: - tool_name = re.search(r"\"name\": \"(.*?)\"", tool_state.buffer) - if tool_name is None: - continue - else: - tool_name = tool_name.group(1) - tool_state.has_tool_name_defined = True - tool = { - "function": {"name": tool_name}, - "index": 0, - "type": "function", - "id": _request_id + "_tool_call", # Only the first tool call delta has an id - } - - # Second step: extract tool arguments. The tool arguments can be seen as a json string - # within the tool json string. We emit a delta for the arguments. - else: - # Empty text: skip - if result == "": - continue - # Until we see the `"arguments": {` in the buffer, we skip - # TODO: other models will likely need more elaborate processing here - if '"arguments": {' not in tool_state.buffer: - continue - - # Handle nesting. We want to exclude the last } from the emitted arguments (it's - # closing the outermost nesting level, outside the arguments block) - tool_state.arg_nesting_level += result.count("{") - tool_state.arg_nesting_level -= result.count("}") - if tool_state.arg_nesting_level < 0: - result = "".join(result.split("}")[:-2]) + "}" # e.g. "4}}\n" -> "4}" - - tool = { - "function": {"arguments": result}, - "index": 0, - "type": "function", - } - - yield self.build_chunk(_request_id, tool_calls=[tool]) - continue - # ====== END OF TOOL CALL LOGIC ====== - - # All non-tool related tokens are emitted as assistant messages. Empty text is skipped. - if result != "": - yield self.build_chunk(_request_id, content=result) - yield self.build_chunk(_request_id, finish_reason="stop") - - thread.join() - except Exception as e: - logger.error(str(e)) - raise - yield f'data: {{"error": "{str(e)}"}}' - - finally: - thread.join() - - return StreamingResponse(stream_response(generation_streamer, request_id), media_type="text/event-stream") - @staticmethod - def get_quantization_config(model_args: ServeArguments) -> Optional["BitsAndBytesConfig"]: - if model_args.load_in_4bit: + def get_quantization_config(args: ServeArguments) -> Optional["BitsAndBytesConfig"]: + """ + Returns the quantization config for the given CLI arguments. + + Args: + args (`ServeArguments`): The serve arguments. May contain quantization settings, device, etc. + + Returns: + `Optional[BitsAndBytesConfig]`: The quantization config. + """ + if args.load_in_4bit: quantization_config = BitsAndBytesConfig( load_in_4bit=True, # For consistency with model weights, we use the same value as `torch_dtype` - bnb_4bit_compute_dtype=model_args.torch_dtype, - bnb_4bit_quant_type=model_args.bnb_4bit_quant_type, - bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant, - bnb_4bit_quant_storage=model_args.torch_dtype, + bnb_4bit_compute_dtype=args.torch_dtype, + bnb_4bit_quant_type=args.bnb_4bit_quant_type, + bnb_4bit_use_double_quant=args.use_bnb_nested_quant, + bnb_4bit_quant_storage=args.torch_dtype, ) - elif model_args.load_in_8bit: + elif args.load_in_8bit: quantization_config = BitsAndBytesConfig( load_in_8bit=True, ) @@ -656,13 +1109,30 @@ class ServeCommand(BaseTransformersCLICommand): return quantization_config def canonicalized_model_name(self, model_id: str) -> str: + """ + Canonicalizes the model name to the format "model_id@revision". If the model_id DOESN'T contain an @, it + defaults to "model_id@main". + + Args: + model_id (`str`): The model ID. + + Returns: + `str`: The canonicalized model name. + """ if "@" in model_id: return model_id return f"{model_id}@main" - def load_model_and_tokenizer( - self, model_id_and_revision: str, args: ServeArguments - ) -> tuple[PreTrainedModel, PreTrainedTokenizerFast]: + def load_model_and_tokenizer(self, model_id_and_revision: str, args: ServeArguments): + """ + Loads the model and tokenizer from the given model ID and revision into the ServeCommand instance. + + Args: + model_id_and_revision (`str`): + The model ID and revision to load. + args (`ServeArguments`): + The serve arguments. May contain quantization settings, device, etc. + """ logger.warning(f"Loading {model_id_and_revision}") if "@" in model_id_and_revision: @@ -699,7 +1169,8 @@ class ServeCommand(BaseTransformersCLICommand): self.loaded_model = f"{model_id}@{revision}" logger.warning(f"Loaded model {self.loaded_model}") - return model, tokenizer + self.model = model + self.tokenizer = tokenizer if __name__ == "__main__": diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index acd8ec7fb4..1f071c6bb6 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -43,6 +43,7 @@ deps = { "onnxconverter-common": "onnxconverter-common", "onnxruntime-tools": "onnxruntime-tools>=1.4.2", "onnxruntime": "onnxruntime>=1.4.0", + "openai": "openai", "opencv-python": "opencv-python", "optimum-benchmark": "optimum-benchmark>=0.3.0", "optuna": "optuna", diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index c582f0e4fb..1df380b6fd 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -112,6 +112,7 @@ from .utils import ( is_natten_available, is_nltk_available, is_onnx_available, + is_openai_available, is_optimum_available, is_optimum_quanto_available, is_pandas_available, @@ -1536,6 +1537,13 @@ def require_speech(test_case): return unittest.skipUnless(is_speech_available(), "test requires torchaudio")(test_case) +def require_openai(test_case): + """ + Decorator marking a test that requires openai + """ + return unittest.skipUnless(is_openai_available(), "test requires openai")(test_case) + + def require_mistral_common(test_case): """ Decorator marking a test that requires mistral-common. These tests are skipped when mistral-common isn't available. diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index c20d3d36f5..251c2309ed 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -515,6 +515,10 @@ def is_uvicorn_available(): return _uvicorn_available +def is_openai_available(): + return _openai_available + + def is_pretty_midi_available(): return _pretty_midi_available @@ -730,10 +734,6 @@ def is_onnx_available(): return _onnx_available -def is_openai_available(): - return _openai_available - - def is_flax_available(): return _flax_available @@ -1916,6 +1916,12 @@ UVICORN_IMPORT_ERROR = """ `pip install uvicorn`. Please note that you may need to restart your runtime after installation. """ +# docstyle-ignore +OPENAI_IMPORT_ERROR = """ +{0} requires the openai library but it was not found in your environment. You can install it with pip: +`pip install openai`. Please note that you may need to restart your runtime after installation. +""" + # docstyle-ignore PYTESSERACT_IMPORT_ERROR = """ {0} requires the PyTesseract library but it was not found in your environment. You can install it with pip: @@ -2046,6 +2052,7 @@ BACKENDS_MAPPING = OrderedDict( ("pydantic", (is_pydantic_available, PYDANTIC_IMPORT_ERROR)), ("fastapi", (is_fastapi_available, FASTAPI_IMPORT_ERROR)), ("uvicorn", (is_uvicorn_available, UVICORN_IMPORT_ERROR)), + ("openai", (is_openai_available, OPENAI_IMPORT_ERROR)), ("mistral-common", (is_mistral_common_available, MISTRAL_COMMON_IMPORT_ERROR)), ] ) diff --git a/tests/commands/test_serving.py b/tests/commands/test_serving.py index 118e4a8be4..403f08ae7e 100644 --- a/tests/commands/test_serving.py +++ b/tests/commands/test_serving.py @@ -24,9 +24,16 @@ from parameterized import parameterized import transformers.commands.transformers_cli as cli from transformers import GenerationConfig from transformers.commands.serving import ServeArguments, ServeCommand -from transformers.testing_utils import CaptureStd, slow +from transformers.testing_utils import CaptureStd, require_openai, slow +from transformers.utils.import_utils import is_openai_available +if is_openai_available(): + from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction + from openai.types.responses import Response, ResponseCreatedEvent + + +@require_openai class ServeCLITest(unittest.TestCase): def test_help(self): """Minimal test: we can invoke the help command.""" @@ -49,36 +56,94 @@ class ServeCLITest(unittest.TestCase): self.assertEqual(parsed_args.host, "0.0.0.0") self.assertEqual(parsed_args.port, 9000) - def test_completions_build_chunk(self): - """Tests that the chunks are correctly built for the Completions API.""" + def test_build_chat_completion_chunk(self): + """ + Tests that the chunks are correctly built for the Chat Completion API. The `choices` checks implictly + confirm that empty fields are not emitted. + """ dummy = ServeCommand.__new__(ServeCommand) dummy.args = type("Args", (), {})() + dummy.loaded_model = "dummy_model@main" + + # The keys for these fields must be present in every chunk + MANDATORY_FIELDS = ["data", "id", "choices", "created", "model", "object", "system_fingerprint"] # Case 1: most fields are provided - chunk = ServeCommand.build_chunk(dummy, request_id="req0", content="hello", finish_reason="stop", role="user") - self.assertIn("chat.completion.chunk", chunk) - self.assertIn("data:", chunk) + chunk = ServeCommand.build_chat_completion_chunk( + dummy, request_id="req0", content="hello", finish_reason="stop", role="user" + ) + for field in MANDATORY_FIELDS: + self.assertIn(field, chunk) self.assertIn( - '"choices": [{"delta": {"content": "hello", "role": "user"}, "index": 0, "finish_reason": "stop"}]', chunk + '"choices":[{"delta":{"content":"hello","role":"user"},"finish_reason":"stop","index":0}]', chunk ) # Case 2: only the role is provided -- other fields in 'choices' are omitted - chunk = ServeCommand.build_chunk(dummy, request_id="req0", role="user") - self.assertIn("chat.completion.chunk", chunk) - self.assertIn("data:", chunk) - self.assertIn('"choices": [{"delta": {"role": "user"}, "index": 0}]', chunk) + chunk = dummy.build_chat_completion_chunk(request_id="req0", role="user") + for field in MANDATORY_FIELDS: + self.assertIn(field, chunk) + self.assertIn('"choices":[{"delta":{"role":"user"},"index":0}]', chunk) # Case 3: only the content is provided -- other fields in 'choices' are omitted - chunk = ServeCommand.build_chunk(dummy, request_id="req0", content="hello") - self.assertIn("chat.completion.chunk", chunk) - self.assertIn("data:", chunk) - self.assertIn('"choices": [{"delta": {"content": "hello"}, "index": 0}]', chunk) + chunk = dummy.build_chat_completion_chunk(request_id="req0", content="hello") + for field in MANDATORY_FIELDS: + self.assertIn(field, chunk) + self.assertIn('"choices":[{"delta":{"content":"hello"},"index":0}]', chunk) - # Case 4: tool calls support a list of nested dictionaries - chunk = ServeCommand.build_chunk(dummy, request_id="req0", tool_calls=[{"foo1": "bar1", "foo2": "bar2"}]) - self.assertIn("chat.completion.chunk", chunk) - self.assertIn("data:", chunk) - self.assertIn('"choices": [{"delta": {"tool_calls": [{"foo1": "bar1", "foo2": "bar2"}]}, "index": 0}]', chunk) + # Case 4: tool calls support a list of ChoiceDeltaToolCall objects + tool_call = ChoiceDeltaToolCall( + index=0, + function=ChoiceDeltaToolCallFunction(name="foo_bar", arguments='{"foo1": "bar1", "foo2": "bar2"}'), + type="function", + ) + chunk = dummy.build_chat_completion_chunk(request_id="req0", tool_calls=[tool_call]) + for field in MANDATORY_FIELDS: + self.assertIn(field, chunk) + expected_choices_content = ( + 'choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\\"foo1\\": \\"bar1\\", ' + '\\"foo2\\": \\"bar2\\"}","name":"foo_bar"},"type":"function"}]},"index":0}]' + ) + self.assertIn(expected_choices_content, chunk) + + def test_build_response_event(self): + """ + Tests that the events are correctly built for the Response API. + + Contrarily to the Chat Completion API, the Response API has a wide set of possible output objects. This test + only checks a few basic assumptions -- we rely on OpenAI's pydantic models to enforce the correct schema. + """ + dummy = ServeCommand.__new__(ServeCommand) + dummy.args = type("Args", (), {})() + + response_created = ResponseCreatedEvent( + type="response.created", + sequence_number=0, + response=Response( + id="resp_0", + created_at=time.time(), + status="queued", + model="dummy_model@main", + instructions=None, # <--- is set to None = should NOT be in the output. + text={"format": {"type": "text"}}, + object="response", + tools=[], # <--- empty lists should be in the output (they are often mandatory fields) + output=[], + parallel_tool_calls=False, + tool_choice="auto", + metadata=None, + ), + ) + + event = dummy.build_response_event(response_created) + self.assertTrue(event.startswith("data: ")) # Sanity check: event formatting + self.assertIn('"model":"dummy_model@main"', event) # Sanity check: set field + self.assertIn('"status":"queued"', event) + self.assertIn("tools", event) # empty lists should be in the output + self.assertIn("output", event) + self.assertNotIn("instructions", event) # None fields should NOT be in the output + self.assertNotIn("metadata", event) + self.assertNotIn("error", event) # Unset optional fields should NOT be in the output + self.assertNotIn("top_p", event) def async_retry(fn, max_attempts=5, delay=2): @@ -105,7 +170,7 @@ class ServeCompletionsMixin: @async_retry async def run_server(self, request): - client = AsyncInferenceClient("http://localhost:8000") + client = AsyncInferenceClient(f"http://localhost:{self.port}") stream = client.chat_completion(**request) all_payloads = [] @@ -119,8 +184,7 @@ class ServeCompletionsMixin: [ ("default_request", {}), ("one_token", {"max_tokens": 1}), - # TODO: CB fails next case, seems like it is unable to switch models. fix me - # ("different_model", {"model": "HuggingFaceTB/SmolLM2-135M-Instruct"}), + ("different_model", {"model": "HuggingFaceTB/SmolLM2-135M-Instruct"}), ( "tool_call", { @@ -191,20 +255,20 @@ class ServeCompletionsMixin: # sets `do_sample=True` self.assertEqual(output_text, '\nOkay, the user just asked, "') - # TODO: implement API-compliant error handling, and then test it - # See https://platform.openai.com/docs/guides/error-codes, # TODO: one test for each request flag, to confirm it is working as expected # TODO: speed-based test to confirm that KV cache is working across requests -@slow # TODO (joao): this shouldn't be needed -class ServeCompletionsGenerateTest(ServeCompletionsMixin, unittest.TestCase): +@slow # server startup time is slow on our push CI +@require_openai +class ServeCompletionsGenerateIntegrationTest(ServeCompletionsMixin, unittest.TestCase): """Tests the `generate` version of the Completions API.""" @classmethod def setUpClass(cls): """Starts a server for tests to connect to.""" - args = ServeArguments() + cls.port = 8001 + args = ServeArguments(port=cls.port) serve_command = ServeCommand(args) thread = Thread(target=serve_command.run) thread.daemon = True @@ -287,15 +351,20 @@ class ServeCompletionsGenerateTest(ServeCompletionsMixin, unittest.TestCase): self.assertTrue(all(reason is None for reason in finish_reasons[:-1])) -@slow # TODO (joao): this shouldn't be needed -class ServeCompletionsContinuousBatchingTest(ServeCompletionsMixin, unittest.TestCase): +@slow # server startup time is slow on our push CI +@require_openai +class ServeCompletionsContinuousBatchingIntegrationTest(ServeCompletionsMixin, unittest.TestCase): """Tests the `continuous_batching` version of the Completions API.""" @classmethod def setUpClass(cls): """Starts a server for tests to connect to.""" - args = ServeArguments(attn_implementation="sdpa_paged") # important: toggle continuous batching + cls.port = 8002 + args = ServeArguments(port=cls.port, attn_implementation="sdpa_paged") # important: toggle continuous batching serve_command = ServeCommand(args) thread = Thread(target=serve_command.run) thread.daemon = True thread.start() + + +# TODO: Response integration tests